sglang 0.3.4__py3-none-any.whl → 0.3.4.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53)
  1. sglang/bench_latency.py +2 -1
  2. sglang/lang/chat_template.py +17 -0
  3. sglang/launch_server_llavavid.py +1 -1
  4. sglang/srt/configs/__init__.py +3 -0
  5. sglang/srt/configs/model_config.py +27 -2
  6. sglang/srt/configs/qwen2vl.py +133 -0
  7. sglang/srt/constrained/fsm_cache.py +10 -3
  8. sglang/srt/conversation.py +27 -0
  9. sglang/srt/hf_transformers_utils.py +16 -1
  10. sglang/srt/layers/attention/__init__.py +16 -5
  11. sglang/srt/layers/attention/double_sparsity_backend.py +22 -6
  12. sglang/srt/layers/attention/flashinfer_backend.py +174 -54
  13. sglang/srt/layers/attention/triton_backend.py +22 -6
  14. sglang/srt/layers/attention/triton_ops/prefill_attention.py +26 -4
  15. sglang/srt/layers/linear.py +89 -63
  16. sglang/srt/layers/logits_processor.py +5 -5
  17. sglang/srt/layers/rotary_embedding.py +112 -0
  18. sglang/srt/layers/sampler.py +51 -39
  19. sglang/srt/lora/lora.py +3 -1
  20. sglang/srt/managers/data_parallel_controller.py +1 -1
  21. sglang/srt/managers/detokenizer_manager.py +4 -0
  22. sglang/srt/managers/image_processor.py +186 -13
  23. sglang/srt/managers/io_struct.py +10 -0
  24. sglang/srt/managers/schedule_batch.py +238 -68
  25. sglang/srt/managers/scheduler.py +69 -50
  26. sglang/srt/managers/tokenizer_manager.py +24 -4
  27. sglang/srt/managers/tp_worker.py +26 -111
  28. sglang/srt/managers/tp_worker_overlap_thread.py +209 -0
  29. sglang/srt/mem_cache/memory_pool.py +56 -10
  30. sglang/srt/mem_cache/radix_cache.py +4 -3
  31. sglang/srt/model_executor/cuda_graph_runner.py +87 -28
  32. sglang/srt/model_executor/forward_batch_info.py +83 -3
  33. sglang/srt/model_executor/model_runner.py +32 -11
  34. sglang/srt/models/chatglm.py +3 -3
  35. sglang/srt/models/deepseek_v2.py +2 -2
  36. sglang/srt/models/mllama.py +1004 -0
  37. sglang/srt/models/qwen2_vl.py +724 -0
  38. sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +6 -3
  39. sglang/srt/sampling/sampling_batch_info.py +13 -3
  40. sglang/srt/sampling/sampling_params.py +5 -7
  41. sglang/srt/server.py +12 -0
  42. sglang/srt/server_args.py +10 -0
  43. sglang/srt/utils.py +22 -0
  44. sglang/test/run_eval.py +2 -0
  45. sglang/test/runners.py +20 -1
  46. sglang/test/srt/sampling/penaltylib/utils.py +1 -0
  47. sglang/test/test_utils.py +100 -3
  48. sglang/version.py +1 -1
  49. {sglang-0.3.4.dist-info → sglang-0.3.4.post2.dist-info}/METADATA +17 -18
  50. {sglang-0.3.4.dist-info → sglang-0.3.4.post2.dist-info}/RECORD +53 -48
  51. {sglang-0.3.4.dist-info → sglang-0.3.4.post2.dist-info}/LICENSE +0 -0
  52. {sglang-0.3.4.dist-info → sglang-0.3.4.post2.dist-info}/WHEEL +0 -0
  53. {sglang-0.3.4.dist-info → sglang-0.3.4.post2.dist-info}/top_level.txt +0 -0
sglang/srt/managers/schedule_batch.py
@@ -23,17 +23,20 @@ ScheduleBatch -> ModelWorkerBatch -> ForwardBatch
 - ScheduleBatch is managed by `scheduler.py::Scheduler`.
   It contains high-level scheduling data. Most of the data is on the CPU.
 - ModelWorkerBatch is managed by `tp_worker.py::TpModelWorker`.
+  It is a subset of `ScheduleBatch` that only contains data related to the model forward on GPU.
+  It will be transformed from CPU scheduler to GPU model runner.
 - ForwardBatch is managed by `model_runner.py::ModelRunner`.
   It contains low-level tensor data. Most of the data consists of GPU tensors.
 """

+import dataclasses
 import logging
-from dataclasses import dataclass
 from typing import List, Optional, Tuple, Union

 import torch

 from sglang.global_config import global_config
+from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.constrained import RegexGuide
 from sglang.srt.constrained.jump_forward import JumpForwardMap
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
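Note: the module docstring above describes the batch hand-off pipeline. A toy sketch of that flow, with invented names and a single field (the real classes carry far more state):

    # Toy illustration of the CPU -> GPU batch hand-off; all names invented.
    import dataclasses
    import torch

    @dataclasses.dataclass
    class TinyModelWorkerBatch:
        input_ids: torch.Tensor  # only the data the GPU forward pass needs

        def to(self, device: str):
            self.input_ids = self.input_ids.to(device, non_blocking=True)

    @dataclasses.dataclass
    class TinyScheduleBatch:
        input_ids_cpu: list  # high-level scheduling data stays on the CPU

        def get_model_worker_batch(self) -> TinyModelWorkerBatch:
            return TinyModelWorkerBatch(
                input_ids=torch.tensor(self.input_ids_cpu, dtype=torch.int32)
            )

    worker_batch = TinyScheduleBatch([1, 2, 3]).get_model_worker_batch()
    # worker_batch.to("cuda")  # done by the model runner when a GPU is present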
@@ -114,38 +117,50 @@ class FINISH_ABORT(BaseFinishReason):
 }


-@dataclass
+@dataclasses.dataclass
 class ImageInputs:
     """The image related inputs."""

     pixel_values: torch.Tensor
-    image_hash: int
+    image_hashes: Optional[list] = None
     image_sizes: Optional[list] = None
     image_offsets: Optional[list] = None
     pad_values: Optional[list] = None
     modalities: Optional[list] = None
+    num_image_tokens: Optional[int] = None

     image_embeds: Optional[List[torch.Tensor]] = None
     aspect_ratio_ids: Optional[List[torch.Tensor]] = None
     aspect_ratio_mask: Optional[List[torch.Tensor]] = None
+    # QWen2-VL related
+    image_grid_thws: List[Tuple[int, int, int]] = None

     @staticmethod
     def from_dict(obj, vocab_size):
         # Use image hash as fake token_ids, which is then used for prefix matching
         ret = ImageInputs(
             pixel_values=obj["pixel_values"],
-            image_hash=hash(tuple(obj["image_hashes"])),
+            image_hashes=hash(tuple(obj["image_hashes"])),
         )
-        image_hash = ret.image_hash
+        image_hash = ret.image_hashes
         ret.pad_values = [
             (image_hash) % vocab_size,
             (image_hash >> 16) % vocab_size,
             (image_hash >> 32) % vocab_size,
             (image_hash >> 64) % vocab_size,
         ]
-        ret.image_sizes = obj["image_sizes"]
-        # Only when pixel values is not None we have modalities
-        ret.modalities = obj["modalities"] or ["image"]
+
+        optional_args = [
+            "image_sizes",
+            "modalities",
+            "aspect_ratio_ids",
+            "aspect_ratio_mask",
+            "image_grid_thws",
+        ]
+        for arg in optional_args:
+            if arg in obj:
+                setattr(ret, arg, obj[arg])
+
         return ret

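Note: `from_dict` above folds all image hashes into one Python hash and slices it into four pseudo token ids (`pad_values`) so the radix cache can prefix-match requests containing the same images. A worked illustration with invented values:

    # Illustration only; mirrors the pad_values computation above.
    vocab_size = 32000
    image_hash = hash(("fake-image-hash-1", "fake-image-hash-2"))
    pad_values = [
        (image_hash) % vocab_size,
        (image_hash >> 16) % vocab_size,
        (image_hash >> 32) % vocab_size,
        (image_hash >> 64) % vocab_size,  # CPython hashes fit in 64 bits,
                                          # so this is 0 or -1 before the modulo
    ]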
@@ -236,6 +251,9 @@ class Req:
         self.regex_fsm_state: int = 0
         self.jump_forward_map: JumpForwardMap = None

+        # For Qwen2-VL
+        self.mrope_position_delta = []  # use mutable object
+
     # whether request reached finished condition
     def finished(self) -> bool:
         return self.finished_reason is not None
@@ -316,15 +334,20 @@ class Req:

         last_token_id = self.output_ids[-1]

-        matched_eos = last_token_id in self.sampling_params.stop_token_ids
+        matched_eos = False

+        # Check stop token ids
+        if self.sampling_params.stop_token_ids:
+            matched_eos = last_token_id in self.sampling_params.stop_token_ids
         if self.tokenizer is not None:
             matched_eos |= last_token_id == self.tokenizer.eos_token_id
-
+            if self.tokenizer.additional_stop_token_ids:
+                matched_eos |= last_token_id in self.tokenizer.additional_stop_token_ids
         if matched_eos and not self.sampling_params.ignore_eos:
             self.finished_reason = FINISH_MATCHED_TOKEN(matched=last_token_id)
             return

+        # Check stop strings
         if len(self.sampling_params.stop_strs) > 0:
             tail_str = self.tokenizer.decode(
                 self.output_ids[-(self.sampling_params.stop_str_max_len + 1) :]
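Note: the reworked stop check above tests per-request `stop_token_ids` first, then the tokenizer's `eos_token_id` plus the new `additional_stop_token_ids`, and only then decodes the output tail to look for stop strings. A condensed, self-contained restatement of that precedence (toy function and dicts, invented for illustration):

    # Toy restatement of the stop-condition precedence; not the real classes.
    def match_stop(last_token_id, tail_str, params, tok):
        matched_eos = False
        if params["stop_token_ids"]:
            matched_eos = last_token_id in params["stop_token_ids"]
        if tok is not None:
            matched_eos |= last_token_id == tok["eos_token_id"]
            if tok["additional_stop_token_ids"]:
                matched_eos |= last_token_id in tok["additional_stop_token_ids"]
        if matched_eos and not params["ignore_eos"]:
            return ("token", last_token_id)
        for stop_str in params["stop_strs"]:
            if stop_str in tail_str:
                return ("string", stop_str)
        return None

    # Example: EOS id 2 ends the request unless ignore_eos is set.
    params = {"stop_token_ids": set(), "ignore_eos": False, "stop_strs": ["</s>"]}
    tok = {"eos_token_id": 2, "additional_stop_token_ids": None}
    assert match_stop(2, "", params, tok) == ("token", 2)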
@@ -399,7 +422,7 @@ class Req:
 bid = 0


-@dataclass
+@dataclasses.dataclass
 class ScheduleBatch:
     """Store all inforamtion of a batch."""

@@ -409,6 +432,9 @@ class ScheduleBatch:
     token_to_kv_pool: BaseTokenToKVPool = None
     tree_cache: BasePrefixCache = None

+    # For utility
+    model_config: ModelConfig = None
+
     forward_mode: ForwardMode = None
     sampling_info: SamplingBatchInfo = None

@@ -416,10 +442,13 @@ class ScheduleBatch:
     input_ids: torch.Tensor = None
     req_pool_indices: torch.Tensor = None
     seq_lens: torch.Tensor = None
+    # The output locations of the KV cache
     out_cache_loc: torch.Tensor = None
-
     output_ids: torch.Tensor = None

+    # The sum of all sequence lengths
+    seq_lens_sum: int = None
+
     # For processing logprobs
     return_logprob: bool = False
     top_logprobs_nums: Optional[List[int]] = None
@@ -428,33 +457,42 @@ class ScheduleBatch:
     prefix_lens: List[int] = None
     extend_lens: List[int] = None
     extend_num_tokens: int = None
-    running_bs: int = None
     decoding_reqs: List[Req] = None

+    # For encoder-decoder
+    encoder_cached: Optional[List[bool]] = None
+    encoder_lens: Optional[torch.Tensor] = None
+    encoder_lens_cpu: Optional[List[int]] = None
+    encoder_out_cache_loc: Optional[torch.Tensor] = None
+
     # Stream
     has_stream: bool = False

-    # device
-    device: str = "cuda"
-
     # Has regex
     has_regex: bool = False

-    @classmethod
-    def init_new(cls, reqs, req_to_token_pool, token_to_kv_pool, tree_cache):
-        return_logprob = any(req.return_logprob for req in reqs)
-        has_stream = any(req.stream for req in reqs)
-        has_regex = any(req.regex_fsm for req in reqs)
+    # device
+    device: str = "cuda"

+    @classmethod
+    def init_new(
+        cls,
+        reqs,
+        req_to_token_pool,
+        token_to_kv_pool,
+        tree_cache,
+        model_config,
+    ):
         return cls(
             reqs=reqs,
             req_to_token_pool=req_to_token_pool,
             token_to_kv_pool=token_to_kv_pool,
             tree_cache=tree_cache,
-            return_logprob=return_logprob,
-            has_stream=has_stream,
+            model_config=model_config,
+            return_logprob=any(req.return_logprob for req in reqs),
+            has_stream=any(req.stream for req in reqs),
+            has_regex=any(req.regex_fsm for req in reqs),
             device=req_to_token_pool.device,
-            has_regex=has_regex,
         )

     def batch_size(self):
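Note: `init_new` now threads the model config through the batch, so later steps (the vocab size for sampling, the `is_encoder_decoder` checks) no longer need extra arguments. A hypothetical call site, modeled on the one in `scheduler.py::Scheduler` (the local variable names stand in for scheduler state):

    # Hypothetical call site; argument values are invented for illustration.
    new_batch = ScheduleBatch.init_new(
        reqs=can_run_list,                    # requests admitted this round
        req_to_token_pool=req_to_token_pool,
        token_to_kv_pool=token_to_kv_pool,
        tree_cache=tree_cache,
        model_config=model_config,            # the new argument in this release
    )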
@@ -481,14 +519,90 @@ class ScheduleBatch:
         out_cache_loc = self.token_to_kv_pool.alloc(num_tokens)

         if out_cache_loc is None:
-            logger.error("Prefill out of memory. Try to lower your batch size.")
+            phase_str = "Prefill" if self.forward_mode.is_extend() else "Decode"
+            logger.error(
+                f"{phase_str} out of memory. Try to lower your batch size.\n"
+                f"Try to allocate {num_tokens} tokens.\n"
+                f"Avaliable tokens: {self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()}\n"
+            )
             if self.tree_cache is not None:
                 self.tree_cache.pretty_print()
             exit(1)

         return out_cache_loc

-    def prepare_for_extend(self, vocab_size: int):
+    def prepare_encoder_info_extend(self, input_ids: List[int], seq_lens: List[int]):
+        self.encoder_lens_cpu = []
+        self.encoder_cached = []
+
+        for req in self.reqs:
+            im = req.image_inputs
+            if im is None or im.num_image_tokens is None:
+                # No image input
+                self.encoder_lens_cpu.append(0)
+                self.encoder_cached.append(True)
+            else:
+                self.encoder_lens_cpu.append(im.num_image_tokens)
+                self.encoder_cached.append(
+                    self.forward_mode.is_decode()
+                    or len(req.prefix_indices) >= im.num_image_tokens
+                )
+
+        self.encoder_lens = torch.tensor(self.encoder_lens_cpu, dtype=torch.int32).to(
+            self.device, non_blocking=True
+        )
+
+        # Strip encoder infos
+        pt = 0
+        decoder_out_cache_loc = []
+        encoder_out_cache_loc = []
+        for i, req in enumerate(self.reqs):
+            encoder_len = self.encoder_lens_cpu[i]
+            seq_lens[i] -= encoder_len
+
+            if len(req.prefix_indices) < encoder_len:
+                # NOTE: the encoder part should considered as a whole
+                assert len(req.prefix_indices) == 0
+                input_ids[i] = input_ids[i][encoder_len:]
+                encoder_out_cache_loc.append(self.out_cache_loc[pt : pt + encoder_len])
+                decoder_out_cache_loc.append(
+                    self.out_cache_loc[pt + encoder_len : pt + req.extend_input_len]
+                )
+                self.extend_lens[i] -= encoder_len
+                self.extend_num_tokens -= encoder_len
+            else:
+                decoder_out_cache_loc.append(
+                    self.out_cache_loc[pt : pt + req.extend_input_len]
+                )
+                self.prefix_lens[i] -= encoder_len
+
+            pt += req.extend_input_len
+
+        # Reassign
+        self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32).to(
+            self.device, non_blocking=True
+        )
+        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32).to(
+            self.device, non_blocking=True
+        )
+
+        if not decoder_out_cache_loc:
+            self.out_cache_loc = torch.empty(0, dtype=torch.int32).to(
+                self.device, non_blocking=True
+            )
+        else:
+            self.out_cache_loc = torch.cat(decoder_out_cache_loc)
+
+        if not encoder_out_cache_loc:
+            self.encoder_out_cache_loc = torch.empty(0, dtype=torch.int32).to(
+                self.device, non_blocking=True
+            )
+        else:
+            self.encoder_out_cache_loc = torch.cat(encoder_out_cache_loc)
+
+        assert len(self.out_cache_loc) == self.extend_num_tokens
+
+    def prepare_for_extend(self):
         self.forward_mode = ForwardMode.EXTEND

         bs = len(self.reqs)
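Note: `prepare_encoder_info_extend` above splits each request's allocated KV-cache slots into an encoder part and a decoder part. A worked example with invented numbers, for one uncached mllama-style request:

    # Invented numbers; one request with 6 image (encoder) tokens + 4 text tokens.
    encoder_len = 6
    extend_input_len = 10                   # encoder + decoder tokens
    out_cache_loc = list(range(100, 110))   # stands in for the allocated KV slots

    encoder_out_cache_loc = out_cache_loc[:encoder_len]   # slots 100..105
    decoder_out_cache_loc = out_cache_loc[encoder_len:]   # slots 106..109
    extend_len_after = extend_input_len - encoder_len     # 4 decoder tokens remain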
@@ -516,12 +630,12 @@ class ScheduleBatch:
             assert seq_len - pre_len == req.extend_input_len

             if pre_len > 0:
-                self.req_to_token_pool.req_to_token[req.req_pool_idx, :pre_len] = (
-                    req.prefix_indices
+                self.req_to_token_pool.write(
+                    (req.req_pool_idx, slice(0, pre_len)), req.prefix_indices
                 )
-
-            self.req_to_token_pool.req_to_token[req.req_pool_idx, pre_len:seq_len] = (
-                out_cache_loc[pt : pt + req.extend_input_len]
+            self.req_to_token_pool.write(
+                (req.req_pool_idx, slice(pre_len, seq_len)),
+                out_cache_loc[pt : pt + req.extend_input_len],
             )

             # Compute the relative logprob_start_len in an extend batch
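Note: writes to the request-to-token mapping now go through `req_to_token_pool.write(...)` instead of direct tensor indexing. This pairs with the `get_write_records()` call that appears later in `get_model_worker_batch`: with the overlap worker enabled, the pool can log each (indices, values) pair so the separate GPU thread can replay them. A minimal sketch of that idea; only the `write`/`get_write_records` names come from this diff, the rest is invented:

    import torch

    class RecordingReqToTokenPool:
        # Sketch only; the real pool lives in sglang/srt/mem_cache/memory_pool.py.
        def __init__(self, size: int, max_context_len: int, record: bool):
            self.req_to_token = torch.zeros(size, max_context_len, dtype=torch.int32)
            self.write_records = [] if record else None

        def write(self, indices, values):
            self.req_to_token[indices] = values
            if self.write_records is not None:
                self.write_records.append((indices, values))

        def get_write_records(self):
            if self.write_records is None:
                return None
            records, self.write_records = self.write_records, []
            return records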
@@ -546,16 +660,23 @@ class ScheduleBatch:
             self.device, non_blocking=True
         )

-        self.extend_num_tokens = extend_num_tokens
         self.out_cache_loc = out_cache_loc
+
+        self.seq_lens_sum = sum(seq_lens)
         if self.return_logprob:
             self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
+        self.extend_num_tokens = extend_num_tokens
         self.prefix_lens = [len(r.prefix_indices) for r in reqs]
         self.extend_lens = [r.extend_input_len for r in reqs]
         self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]

+        if self.model_config.is_encoder_decoder:
+            self.prepare_encoder_info_extend(input_ids, seq_lens)
+
         self.sampling_info = SamplingBatchInfo.from_schedule_batch(
-            self, vocab_size, global_server_args_dict["disable_penalizer"]
+            self,
+            self.model_config.vocab_size,
+            global_server_args_dict["disable_penalizer"],
         )

     def mix_with_running(self, running_batch: "ScheduleBatch"):
@@ -568,12 +689,11 @@ class ScheduleBatch:

         input_ids = torch.cat([self.input_ids, running_batch.input_ids])
         out_cache_loc = torch.cat([self.out_cache_loc, running_batch.out_cache_loc])
-        extend_num_tokens = self.extend_num_tokens + running_bs

         self.merge_batch(running_batch)
         self.input_ids = input_ids
         self.out_cache_loc = out_cache_loc
-        self.extend_num_tokens = extend_num_tokens
+        self.extend_num_tokens += running_bs

         # NOTE: prefix_indices is what has been cached, but we don't cache each decode step
         self.prefix_lens.extend(
@@ -631,8 +751,8 @@ class ScheduleBatch:

         if isinstance(self.tree_cache, ChunkCache):
             # ChunkCache does not have eviction
-            token_indices = self.req_to_token_pool.req_to_token[req.req_pool_idx][
-                : seq_lens_cpu[idx]
+            token_indices = self.req_to_token_pool.req_to_token[
+                req.req_pool_idx, : seq_lens_cpu[idx]
             ]
             self.token_to_kv_pool.free(token_indices)
             self.req_to_token_pool.free(req.req_pool_idx)
@@ -640,8 +760,8 @@ class ScheduleBatch:
         else:
             # TODO: apply more fine-grained retraction
             last_uncached_pos = len(req.prefix_indices)
-            token_indices = self.req_to_token_pool.req_to_token[req.req_pool_idx][
-                last_uncached_pos : seq_lens_cpu[idx]
+            token_indices = self.req_to_token_pool.req_to_token[
+                req.req_pool_idx, last_uncached_pos : seq_lens_cpu[idx]
             ]
             self.token_to_kv_pool.free(token_indices)
             self.req_to_token_pool.free(req.req_pool_idx)
@@ -746,7 +866,11 @@ class ScheduleBatch:

         return jump_forward_reqs

-    def prepare_for_decode(self):
+    def prepare_encoder_info_decode(self):
+        # Reset the encoder cached status
+        self.encoder_cached = [True] * len(self.reqs)
+
+    def prepare_for_decode(self, enable_overlap: bool = False):
         self.forward_mode = ForwardMode.DECODE

         self.input_ids = self.output_ids
@@ -760,10 +884,25 @@ class ScheduleBatch:
         bs = len(self.reqs)
         self.out_cache_loc = self.alloc_token_slots(bs)

-        self.req_to_token_pool.req_to_token[self.req_pool_indices, self.seq_lens] = (
-            self.out_cache_loc
-        )
-        self.seq_lens.add_(1)
+        if self.model_config.is_encoder_decoder:
+            locs = self.encoder_lens + self.seq_lens
+            self.prepare_encoder_info_decode()
+        else:
+            locs = self.seq_lens
+
+        if enable_overlap:
+            # Do not use in-place operations in the overlap mode
+            self.req_to_token_pool.write(
+                (self.req_pool_indices, locs), self.out_cache_loc
+            )
+            self.seq_lens = self.seq_lens + 1
+        else:
+            # A faster in-place version
+            self.req_to_token_pool.write(
+                (self.req_pool_indices, locs), self.out_cache_loc
+            )
+            self.seq_lens.add_(1)
+        self.seq_lens_sum += bs

     def filter_batch(
         self,
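Note: the new `enable_overlap` branch avoids in-place tensor updates so that the overlap worker thread, which may still hold a reference to the previous `seq_lens` tensor, never sees it mutate mid-flight. The difference in a toy form:

    import torch

    seq_lens = torch.tensor([5, 9, 3])
    alias = seq_lens                 # e.g. a reference kept by the overlap thread

    seq_lens = seq_lens + 1          # overlap mode: new tensor, alias unchanged
    assert alias[0].item() == 5

    alias.add_(1)                    # in-place mode: every alias sees the change
    assert alias[0].item() == 6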
@@ -787,6 +926,10 @@ class ScheduleBatch:
             # No need to filter
             return

+        if self.model_config.is_encoder_decoder:
+            self.encoder_lens = self.encoder_lens[keep_indices]
+            self.encoder_lens_cpu = [self.encoder_lens_cpu[i] for i in keep_indices]
+
         self.reqs = [self.reqs[i] for i in keep_indices]
         new_indices = torch.tensor(keep_indices, dtype=torch.int32).to(
             self.device, non_blocking=True
@@ -794,6 +937,7 @@ class ScheduleBatch:
         self.req_pool_indices = self.req_pool_indices[new_indices]
         self.seq_lens = self.seq_lens[new_indices]
         self.out_cache_loc = None
+        self.seq_lens_sum = self.seq_lens.sum().item()
         self.output_ids = self.output_ids[new_indices]
         self.return_logprob = any(req.return_logprob for req in self.reqs)
         if self.return_logprob:
@@ -812,11 +956,17 @@ class ScheduleBatch:
         # needs to be called with pre-merged Batch.reqs.
         self.sampling_info.merge_batch(other.sampling_info)

+        # Encoder-decoder infos
+        if self.model_config.is_encoder_decoder:
+            self.encoder_lens = torch.cat([self.encoder_lens, other.encoder_lens])
+            self.encoder_lens_cpu.extend(other.encoder_lens_cpu)
+
         self.req_pool_indices = torch.concat(
             [self.req_pool_indices, other.req_pool_indices]
         )
         self.seq_lens = torch.concat([self.seq_lens, other.seq_lens])
         self.out_cache_loc = None
+        self.seq_lens_sum += other.seq_lens_sum
         if self.output_ids is not None:
             self.output_ids = torch.concat([self.output_ids, other.output_ids])
         if self.return_logprob and other.return_logprob:
@@ -833,16 +983,12 @@ class ScheduleBatch:

     def get_model_worker_batch(self):
         if self.forward_mode.is_decode():
-            extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = (
-                image_inputs
-            ) = None
+            extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = None
         else:
             extend_seq_lens = self.extend_lens
             extend_prefix_lens = self.prefix_lens
             extend_logprob_start_lens = self.extend_logprob_start_lens
-            image_inputs = [r.image_inputs for r in self.reqs]

-        lora_paths = [req.lora_path for req in self.reqs]
         if self.has_regex:
             self.sampling_info.regex_fsms = [req.regex_fsm for req in self.reqs]
             self.sampling_info.regex_fsm_states = [
@@ -854,6 +1000,8 @@ class ScheduleBatch:
         global bid
         bid += 1

+        mrope_positions_delta = [req.mrope_position_delta for req in self.reqs]
+
         return ModelWorkerBatch(
             bid=bid,
             forward_mode=self.forward_mode,
@@ -861,19 +1009,29 @@ class ScheduleBatch:
             req_pool_indices=self.req_pool_indices,
             seq_lens=self.seq_lens,
             out_cache_loc=self.out_cache_loc,
+            seq_lens_sum=self.seq_lens_sum,
+            req_to_token_pool_records=self.req_to_token_pool.get_write_records(),
             return_logprob=self.return_logprob,
             top_logprobs_nums=self.top_logprobs_nums,
+            extend_num_tokens=self.extend_num_tokens,
             extend_seq_lens=extend_seq_lens,
             extend_prefix_lens=extend_prefix_lens,
             extend_logprob_start_lens=extend_logprob_start_lens,
-            image_inputs=image_inputs,
-            lora_paths=lora_paths,
+            image_inputs=[r.image_inputs for r in self.reqs],
+            encoder_cached=self.encoder_cached,
+            encoder_lens=self.encoder_lens,
+            encoder_lens_cpu=self.encoder_lens_cpu,
+            encoder_out_cache_loc=self.encoder_out_cache_loc,
+            lora_paths=[req.lora_path for req in self.reqs],
             sampling_info=self.sampling_info,
+            mrope_positions_delta=mrope_positions_delta,
         )

     def copy(self):
+        # Only contain fields that will be used by process_batch_result
         return ScheduleBatch(
             reqs=self.reqs,
+            model_config=self.model_config,
             forward_mode=self.forward_mode,
             out_cache_loc=self.out_cache_loc,
             return_logprob=self.return_logprob,
@@ -887,7 +1045,7 @@ class ScheduleBatch:
         )


-@dataclass
+@dataclasses.dataclass
 class ModelWorkerBatch:
     # The batch id
     bid: int
@@ -902,11 +1060,18 @@ class ModelWorkerBatch:
     # The indices of output tokens in the token_to_kv_pool
     out_cache_loc: torch.Tensor

+    # The sum of all sequence lengths
+    seq_lens_sum: int
+
+    # The memory pool operation records
+    req_to_token_pool_records: Optional[List[Tuple[Tuple, torch.Tensor]]]
+
     # For logprob
     return_logprob: bool
     top_logprobs_nums: Optional[List[int]]

     # For extend
+    extend_num_tokens: Optional[int]
     extend_seq_lens: Optional[List[int]]
     extend_prefix_lens: Optional[List[int]]
     extend_logprob_start_lens: Optional[List[int]]
@@ -914,26 +1079,31 @@ class ModelWorkerBatch:
     # For multimodal
     image_inputs: Optional[List[ImageInputs]]

+    # For encoder-decoder
+    encoder_cached: Optional[List[bool]]
+    encoder_lens: Optional[torch.Tensor]
+    encoder_lens_cpu: Optional[List[int]]
+    encoder_out_cache_loc: Optional[torch.Tensor]
+
     # For LoRA
     lora_paths: Optional[List[str]]

     # Sampling info
     sampling_info: SamplingBatchInfo

+    # For Qwen2-VL
+    mrope_positions_delta: List[List[int]]
+
     def copy(self):
-        return ModelWorkerBatch(
-            bid=self.bid,
-            forward_mode=self.forward_mode,
-            input_ids=self.input_ids.clone(),
-            req_pool_indices=self.req_pool_indices,
-            seq_lens=self.seq_lens.clone(),
-            out_cache_loc=self.out_cache_loc,
-            return_logprob=self.return_logprob,
-            top_logprobs_nums=self.top_logprobs_nums,
-            extend_seq_lens=self.extend_seq_lens,
-            extend_prefix_lens=self.extend_prefix_lens,
-            extend_logprob_start_lens=self.extend_logprob_start_lens,
-            image_inputs=self.image_inputs,
-            lora_paths=self.lora_paths,
-            sampling_info=self.sampling_info.copy(),
-        )
+        return dataclasses.replace(self, sampling_info=self.sampling_info.copy())
+
+    def to(self, device: str):
+        self.input_ids = self.input_ids.to(device, non_blocking=True)
+        self.req_pool_indices = self.req_pool_indices.to(device, non_blocking=True)
+        self.seq_lens = self.seq_lens.to(device, non_blocking=True)
+        self.out_cache_loc = self.out_cache_loc.to(device, non_blocking=True)
+        self.req_to_token_pool_records = [
+            (x, y.to(device, non_blocking=True))
+            for x, y in self.req_to_token_pool_records
+        ]
+        self.sampling_info.to(device)
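Note: `copy()` now leans on `dataclasses.replace`, which builds a new `ModelWorkerBatch` sharing every field except those explicitly overridden; only `sampling_info` gets its own copy. A minimal illustration with a toy dataclass (invented for this note):

    import dataclasses

    @dataclasses.dataclass
    class Point:
        xs: list
        tag: str

    p = Point(xs=[1, 2], tag="a")
    q = dataclasses.replace(p, tag="b")   # new object; unlisted fields are shared
    assert q.xs is p.xs and q.tag == "b" and p.tag == "a"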