sglang 0.3.4__py3-none-any.whl → 0.3.4.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_latency.py +2 -1
- sglang/lang/chat_template.py +17 -0
- sglang/launch_server_llavavid.py +1 -1
- sglang/srt/configs/__init__.py +3 -0
- sglang/srt/configs/model_config.py +27 -2
- sglang/srt/configs/qwen2vl.py +133 -0
- sglang/srt/constrained/fsm_cache.py +10 -3
- sglang/srt/conversation.py +27 -0
- sglang/srt/hf_transformers_utils.py +16 -1
- sglang/srt/layers/attention/__init__.py +16 -5
- sglang/srt/layers/attention/double_sparsity_backend.py +22 -6
- sglang/srt/layers/attention/flashinfer_backend.py +174 -54
- sglang/srt/layers/attention/triton_backend.py +22 -6
- sglang/srt/layers/attention/triton_ops/prefill_attention.py +26 -4
- sglang/srt/layers/linear.py +89 -63
- sglang/srt/layers/logits_processor.py +5 -5
- sglang/srt/layers/rotary_embedding.py +112 -0
- sglang/srt/layers/sampler.py +51 -39
- sglang/srt/lora/lora.py +3 -1
- sglang/srt/managers/data_parallel_controller.py +1 -1
- sglang/srt/managers/detokenizer_manager.py +4 -0
- sglang/srt/managers/image_processor.py +186 -13
- sglang/srt/managers/io_struct.py +10 -0
- sglang/srt/managers/schedule_batch.py +238 -68
- sglang/srt/managers/scheduler.py +69 -50
- sglang/srt/managers/tokenizer_manager.py +24 -4
- sglang/srt/managers/tp_worker.py +26 -111
- sglang/srt/managers/tp_worker_overlap_thread.py +209 -0
- sglang/srt/mem_cache/memory_pool.py +56 -10
- sglang/srt/mem_cache/radix_cache.py +4 -3
- sglang/srt/model_executor/cuda_graph_runner.py +87 -28
- sglang/srt/model_executor/forward_batch_info.py +83 -3
- sglang/srt/model_executor/model_runner.py +32 -11
- sglang/srt/models/chatglm.py +3 -3
- sglang/srt/models/deepseek_v2.py +2 -2
- sglang/srt/models/mllama.py +1004 -0
- sglang/srt/models/qwen2_vl.py +724 -0
- sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +6 -3
- sglang/srt/sampling/sampling_batch_info.py +13 -3
- sglang/srt/sampling/sampling_params.py +5 -7
- sglang/srt/server.py +12 -0
- sglang/srt/server_args.py +10 -0
- sglang/srt/utils.py +22 -0
- sglang/test/run_eval.py +2 -0
- sglang/test/runners.py +20 -1
- sglang/test/srt/sampling/penaltylib/utils.py +1 -0
- sglang/test/test_utils.py +100 -3
- sglang/version.py +1 -1
- {sglang-0.3.4.dist-info → sglang-0.3.4.post2.dist-info}/METADATA +17 -18
- {sglang-0.3.4.dist-info → sglang-0.3.4.post2.dist-info}/RECORD +53 -48
- {sglang-0.3.4.dist-info → sglang-0.3.4.post2.dist-info}/LICENSE +0 -0
- {sglang-0.3.4.dist-info → sglang-0.3.4.post2.dist-info}/WHEEL +0 -0
- {sglang-0.3.4.dist-info → sglang-0.3.4.post2.dist-info}/top_level.txt +0 -0
sglang/srt/managers/schedule_batch.py

@@ -23,17 +23,20 @@ ScheduleBatch -> ModelWorkerBatch -> ForwardBatch
 - ScheduleBatch is managed by `scheduler.py::Scheduler`.
   It contains high-level scheduling data. Most of the data is on the CPU.
 - ModelWorkerBatch is managed by `tp_worker.py::TpModelWorker`.
+  It is a subset of `ScheduleBatch` that only contains data related to the model forward on GPU.
+  It will be transformed from CPU scheduler to GPU model runner.
 - ForwardBatch is managed by `model_runner.py::ModelRunner`.
   It contains low-level tensor data. Most of the data consists of GPU tensors.
 """

+import dataclasses
 import logging
-from dataclasses import dataclass
 from typing import List, Optional, Tuple, Union

 import torch

 from sglang.global_config import global_config
+from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.constrained import RegexGuide
 from sglang.srt.constrained.jump_forward import JumpForwardMap
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
@@ -114,38 +117,50 @@ class FINISH_ABORT(BaseFinishReason):
     }


-@dataclass
+@dataclasses.dataclass
 class ImageInputs:
     """The image related inputs."""

     pixel_values: torch.Tensor
-
+    image_hashes: Optional[list] = None
     image_sizes: Optional[list] = None
     image_offsets: Optional[list] = None
     pad_values: Optional[list] = None
     modalities: Optional[list] = None
+    num_image_tokens: Optional[int] = None

     image_embeds: Optional[List[torch.Tensor]] = None
     aspect_ratio_ids: Optional[List[torch.Tensor]] = None
     aspect_ratio_mask: Optional[List[torch.Tensor]] = None
+    # QWen2-VL related
+    image_grid_thws: List[Tuple[int, int, int]] = None

     @staticmethod
     def from_dict(obj, vocab_size):
         # Use image hash as fake token_ids, which is then used for prefix matching
         ret = ImageInputs(
             pixel_values=obj["pixel_values"],
-
+            image_hashes=hash(tuple(obj["image_hashes"])),
         )
-        image_hash = ret.
+        image_hash = ret.image_hashes
         ret.pad_values = [
             (image_hash) % vocab_size,
             (image_hash >> 16) % vocab_size,
             (image_hash >> 32) % vocab_size,
             (image_hash >> 64) % vocab_size,
         ]
-
-
-
+
+        optional_args = [
+            "image_sizes",
+            "modalities",
+            "aspect_ratio_ids",
+            "aspect_ratio_mask",
+            "image_grid_thws",
+        ]
+        for arg in optional_args:
+            if arg in obj:
+                setattr(ret, arg, obj[arg])
+
         return ret


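The `from_dict` hunk above folds the image hash into `pad_values` by shifting the 64-bit hash and reducing each chunk modulo the vocabulary size, so image placeholders can be treated like ordinary token ids during prefix matching. A minimal, self-contained sketch of that shift-and-modulo pattern (the vocab size and hashed object are made up for illustration; this is not sglang code):

# Illustrative only: derive fake "pad token ids" from an image hash,
# mirroring the shift-and-modulo pattern in ImageInputs.from_dict.
vocab_size = 32000  # assumed value for illustration
image_hash = hash(("example-image-bytes",))

pad_values = [
    (image_hash) % vocab_size,
    (image_hash >> 16) % vocab_size,
    (image_hash >> 32) % vocab_size,
    (image_hash >> 64) % vocab_size,
]
print(pad_values)  # four ids in [0, vocab_size); the same hash maps to the same ids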
@@ -236,6 +251,9 @@ class Req:
         self.regex_fsm_state: int = 0
         self.jump_forward_map: JumpForwardMap = None

+        # For Qwen2-VL
+        self.mrope_position_delta = []  # use mutable object
+
     # whether request reached finished condition
     def finished(self) -> bool:
         return self.finished_reason is not None
@@ -316,15 +334,20 @@ class Req:

         last_token_id = self.output_ids[-1]

-        matched_eos =
+        matched_eos = False

+        # Check stop token ids
+        if self.sampling_params.stop_token_ids:
+            matched_eos = last_token_id in self.sampling_params.stop_token_ids
         if self.tokenizer is not None:
             matched_eos |= last_token_id == self.tokenizer.eos_token_id
-
+            if self.tokenizer.additional_stop_token_ids:
+                matched_eos |= last_token_id in self.tokenizer.additional_stop_token_ids
         if matched_eos and not self.sampling_params.ignore_eos:
             self.finished_reason = FINISH_MATCHED_TOKEN(matched=last_token_id)
             return

+        # Check stop strings
         if len(self.sampling_params.stop_strs) > 0:
             tail_str = self.tokenizer.decode(
                 self.output_ids[-(self.sampling_params.stop_str_max_len + 1) :]
@@ -399,7 +422,7 @@ class Req:
 bid = 0


-@dataclass
+@dataclasses.dataclass
 class ScheduleBatch:
     """Store all inforamtion of a batch."""

@@ -409,6 +432,9 @@ class ScheduleBatch:
     token_to_kv_pool: BaseTokenToKVPool = None
     tree_cache: BasePrefixCache = None

+    # For utility
+    model_config: ModelConfig = None
+
     forward_mode: ForwardMode = None
     sampling_info: SamplingBatchInfo = None

@@ -416,10 +442,13 @@
     input_ids: torch.Tensor = None
     req_pool_indices: torch.Tensor = None
     seq_lens: torch.Tensor = None
+    # The output locations of the KV cache
     out_cache_loc: torch.Tensor = None
-
     output_ids: torch.Tensor = None

+    # The sum of all sequence lengths
+    seq_lens_sum: int = None
+
     # For processing logprobs
     return_logprob: bool = False
     top_logprobs_nums: Optional[List[int]] = None
@@ -428,33 +457,42 @@
     prefix_lens: List[int] = None
     extend_lens: List[int] = None
     extend_num_tokens: int = None
-    running_bs: int = None
     decoding_reqs: List[Req] = None

+    # For encoder-decoder
+    encoder_cached: Optional[List[bool]] = None
+    encoder_lens: Optional[torch.Tensor] = None
+    encoder_lens_cpu: Optional[List[int]] = None
+    encoder_out_cache_loc: Optional[torch.Tensor] = None
+
     # Stream
     has_stream: bool = False

-    # device
-    device: str = "cuda"
-
     # Has regex
     has_regex: bool = False

-
-
-        return_logprob = any(req.return_logprob for req in reqs)
-        has_stream = any(req.stream for req in reqs)
-        has_regex = any(req.regex_fsm for req in reqs)
+    # device
+    device: str = "cuda"

+    @classmethod
+    def init_new(
+        cls,
+        reqs,
+        req_to_token_pool,
+        token_to_kv_pool,
+        tree_cache,
+        model_config,
+    ):
         return cls(
             reqs=reqs,
             req_to_token_pool=req_to_token_pool,
             token_to_kv_pool=token_to_kv_pool,
             tree_cache=tree_cache,
-
-
+            model_config=model_config,
+            return_logprob=any(req.return_logprob for req in reqs),
+            has_stream=any(req.stream for req in reqs),
+            has_regex=any(req.regex_fsm for req in reqs),
             device=req_to_token_pool.device,
-            has_regex=has_regex,
         )

     def batch_size(self):
@@ -481,14 +519,90 @@
         out_cache_loc = self.token_to_kv_pool.alloc(num_tokens)

         if out_cache_loc is None:
-
+            phase_str = "Prefill" if self.forward_mode.is_extend() else "Decode"
+            logger.error(
+                f"{phase_str} out of memory. Try to lower your batch size.\n"
+                f"Try to allocate {num_tokens} tokens.\n"
+                f"Avaliable tokens: {self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()}\n"
+            )
             if self.tree_cache is not None:
                 self.tree_cache.pretty_print()
             exit(1)

         return out_cache_loc

-    def
+    def prepare_encoder_info_extend(self, input_ids: List[int], seq_lens: List[int]):
+        self.encoder_lens_cpu = []
+        self.encoder_cached = []
+
+        for req in self.reqs:
+            im = req.image_inputs
+            if im is None or im.num_image_tokens is None:
+                # No image input
+                self.encoder_lens_cpu.append(0)
+                self.encoder_cached.append(True)
+            else:
+                self.encoder_lens_cpu.append(im.num_image_tokens)
+                self.encoder_cached.append(
+                    self.forward_mode.is_decode()
+                    or len(req.prefix_indices) >= im.num_image_tokens
+                )
+
+        self.encoder_lens = torch.tensor(self.encoder_lens_cpu, dtype=torch.int32).to(
+            self.device, non_blocking=True
+        )
+
+        # Strip encoder infos
+        pt = 0
+        decoder_out_cache_loc = []
+        encoder_out_cache_loc = []
+        for i, req in enumerate(self.reqs):
+            encoder_len = self.encoder_lens_cpu[i]
+            seq_lens[i] -= encoder_len
+
+            if len(req.prefix_indices) < encoder_len:
+                # NOTE: the encoder part should considered as a whole
+                assert len(req.prefix_indices) == 0
+                input_ids[i] = input_ids[i][encoder_len:]
+                encoder_out_cache_loc.append(self.out_cache_loc[pt : pt + encoder_len])
+                decoder_out_cache_loc.append(
+                    self.out_cache_loc[pt + encoder_len : pt + req.extend_input_len]
+                )
+                self.extend_lens[i] -= encoder_len
+                self.extend_num_tokens -= encoder_len
+            else:
+                decoder_out_cache_loc.append(
+                    self.out_cache_loc[pt : pt + req.extend_input_len]
+                )
+                self.prefix_lens[i] -= encoder_len
+
+            pt += req.extend_input_len
+
+        # Reassign
+        self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32).to(
+            self.device, non_blocking=True
+        )
+        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32).to(
+            self.device, non_blocking=True
+        )
+
+        if not decoder_out_cache_loc:
+            self.out_cache_loc = torch.empty(0, dtype=torch.int32).to(
+                self.device, non_blocking=True
+            )
+        else:
+            self.out_cache_loc = torch.cat(decoder_out_cache_loc)
+
+        if not encoder_out_cache_loc:
+            self.encoder_out_cache_loc = torch.empty(0, dtype=torch.int32).to(
+                self.device, non_blocking=True
+            )
+        else:
+            self.encoder_out_cache_loc = torch.cat(encoder_out_cache_loc)
+
+        assert len(self.out_cache_loc) == self.extend_num_tokens
+
+    def prepare_for_extend(self):
         self.forward_mode = ForwardMode.EXTEND

         bs = len(self.reqs)
@@ -516,12 +630,12 @@
             assert seq_len - pre_len == req.extend_input_len

             if pre_len > 0:
-                self.req_to_token_pool.
-                    req.prefix_indices
+                self.req_to_token_pool.write(
+                    (req.req_pool_idx, slice(0, pre_len)), req.prefix_indices
                 )
-
-
-                out_cache_loc[pt : pt + req.extend_input_len]
+            self.req_to_token_pool.write(
+                (req.req_pool_idx, slice(pre_len, seq_len)),
+                out_cache_loc[pt : pt + req.extend_input_len],
             )

             # Compute the relative logprob_start_len in an extend batch
@@ -546,16 +660,23 @@
             self.device, non_blocking=True
         )

-        self.extend_num_tokens = extend_num_tokens
         self.out_cache_loc = out_cache_loc
+
+        self.seq_lens_sum = sum(seq_lens)
         if self.return_logprob:
             self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
+        self.extend_num_tokens = extend_num_tokens
         self.prefix_lens = [len(r.prefix_indices) for r in reqs]
         self.extend_lens = [r.extend_input_len for r in reqs]
         self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]

+        if self.model_config.is_encoder_decoder:
+            self.prepare_encoder_info_extend(input_ids, seq_lens)
+
         self.sampling_info = SamplingBatchInfo.from_schedule_batch(
-            self,
+            self,
+            self.model_config.vocab_size,
+            global_server_args_dict["disable_penalizer"],
         )

     def mix_with_running(self, running_batch: "ScheduleBatch"):
@@ -568,12 +689,11 @@

         input_ids = torch.cat([self.input_ids, running_batch.input_ids])
         out_cache_loc = torch.cat([self.out_cache_loc, running_batch.out_cache_loc])
-        extend_num_tokens = self.extend_num_tokens + running_bs

         self.merge_batch(running_batch)
         self.input_ids = input_ids
         self.out_cache_loc = out_cache_loc
-        self.extend_num_tokens
+        self.extend_num_tokens += running_bs

         # NOTE: prefix_indices is what has been cached, but we don't cache each decode step
         self.prefix_lens.extend(
@@ -631,8 +751,8 @@

             if isinstance(self.tree_cache, ChunkCache):
                 # ChunkCache does not have eviction
-                token_indices = self.req_to_token_pool.req_to_token[
-                    : seq_lens_cpu[idx]
+                token_indices = self.req_to_token_pool.req_to_token[
+                    req.req_pool_idx, : seq_lens_cpu[idx]
                 ]
                 self.token_to_kv_pool.free(token_indices)
                 self.req_to_token_pool.free(req.req_pool_idx)
@@ -640,8 +760,8 @@
             else:
                 # TODO: apply more fine-grained retraction
                 last_uncached_pos = len(req.prefix_indices)
-                token_indices = self.req_to_token_pool.req_to_token[
-                    last_uncached_pos : seq_lens_cpu[idx]
+                token_indices = self.req_to_token_pool.req_to_token[
+                    req.req_pool_idx, last_uncached_pos : seq_lens_cpu[idx]
                 ]
                 self.token_to_kv_pool.free(token_indices)
                 self.req_to_token_pool.free(req.req_pool_idx)
@@ -746,7 +866,11 @@

         return jump_forward_reqs

-    def
+    def prepare_encoder_info_decode(self):
+        # Reset the encoder cached status
+        self.encoder_cached = [True] * len(self.reqs)
+
+    def prepare_for_decode(self, enable_overlap: bool = False):
         self.forward_mode = ForwardMode.DECODE

         self.input_ids = self.output_ids
@@ -760,10 +884,25 @@
         bs = len(self.reqs)
         self.out_cache_loc = self.alloc_token_slots(bs)

-        self.
-        self.
-
-
+        if self.model_config.is_encoder_decoder:
+            locs = self.encoder_lens + self.seq_lens
+            self.prepare_encoder_info_decode()
+        else:
+            locs = self.seq_lens
+
+        if enable_overlap:
+            # Do not use in-place operations in the overlap mode
+            self.req_to_token_pool.write(
+                (self.req_pool_indices, locs), self.out_cache_loc
+            )
+            self.seq_lens = self.seq_lens + 1
+        else:
+            # A faster in-place version
+            self.req_to_token_pool.write(
+                (self.req_pool_indices, locs), self.out_cache_loc
+            )
+            self.seq_lens.add_(1)
+        self.seq_lens_sum += bs

     def filter_batch(
         self,
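The `enable_overlap` branch above exists because, with the overlap scheduler, the CPU prepares the next decode step while the GPU may still be reading the current `seq_lens` tensor; building a new tensor avoids mutating data another consumer still holds, while the in-place `add_(1)` saves an allocation when nothing else references it. A small stand-alone torch illustration of the difference (not sglang code):

import torch

seq_lens = torch.tensor([5, 9, 3])
view_held_elsewhere = seq_lens  # e.g. still referenced by an in-flight batch

# Overlap-safe: allocate a new tensor; the old one is left untouched.
new_seq_lens = seq_lens + 1

# In-place: faster (no allocation), but every holder of the old tensor
# now sees the incremented values as well.
seq_lens.add_(1)

print(new_seq_lens.tolist())         # [6, 10, 4]
print(view_held_elsewhere.tolist())  # [6, 10, 4] -- mutated by add_()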
@@ -787,6 +926,10 @@
             # No need to filter
             return

+        if self.model_config.is_encoder_decoder:
+            self.encoder_lens = self.encoder_lens[keep_indices]
+            self.encoder_lens_cpu = [self.encoder_lens_cpu[i] for i in keep_indices]
+
         self.reqs = [self.reqs[i] for i in keep_indices]
         new_indices = torch.tensor(keep_indices, dtype=torch.int32).to(
             self.device, non_blocking=True
@@ -794,6 +937,7 @@
         self.req_pool_indices = self.req_pool_indices[new_indices]
         self.seq_lens = self.seq_lens[new_indices]
         self.out_cache_loc = None
+        self.seq_lens_sum = self.seq_lens.sum().item()
        self.output_ids = self.output_ids[new_indices]
        self.return_logprob = any(req.return_logprob for req in self.reqs)
        if self.return_logprob:
@@ -812,11 +956,17 @@
         # needs to be called with pre-merged Batch.reqs.
         self.sampling_info.merge_batch(other.sampling_info)

+        # Encoder-decoder infos
+        if self.model_config.is_encoder_decoder:
+            self.encoder_lens = torch.cat([self.encoder_lens, other.encoder_lens])
+            self.encoder_lens_cpu.extend(other.encoder_lens_cpu)
+
         self.req_pool_indices = torch.concat(
             [self.req_pool_indices, other.req_pool_indices]
         )
         self.seq_lens = torch.concat([self.seq_lens, other.seq_lens])
         self.out_cache_loc = None
+        self.seq_lens_sum += other.seq_lens_sum
         if self.output_ids is not None:
             self.output_ids = torch.concat([self.output_ids, other.output_ids])
         if self.return_logprob and other.return_logprob:
@@ -833,16 +983,12 @@

     def get_model_worker_batch(self):
         if self.forward_mode.is_decode():
-            extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens =
-                image_inputs
-            ) = None
+            extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = None
         else:
             extend_seq_lens = self.extend_lens
             extend_prefix_lens = self.prefix_lens
             extend_logprob_start_lens = self.extend_logprob_start_lens
-            image_inputs = [r.image_inputs for r in self.reqs]

-        lora_paths = [req.lora_path for req in self.reqs]
         if self.has_regex:
             self.sampling_info.regex_fsms = [req.regex_fsm for req in self.reqs]
             self.sampling_info.regex_fsm_states = [
@@ -854,6 +1000,8 @@
         global bid
         bid += 1

+        mrope_positions_delta = [req.mrope_position_delta for req in self.reqs]
+
         return ModelWorkerBatch(
             bid=bid,
             forward_mode=self.forward_mode,
@@ -861,19 +1009,29 @@
             req_pool_indices=self.req_pool_indices,
             seq_lens=self.seq_lens,
             out_cache_loc=self.out_cache_loc,
+            seq_lens_sum=self.seq_lens_sum,
+            req_to_token_pool_records=self.req_to_token_pool.get_write_records(),
             return_logprob=self.return_logprob,
             top_logprobs_nums=self.top_logprobs_nums,
+            extend_num_tokens=self.extend_num_tokens,
             extend_seq_lens=extend_seq_lens,
             extend_prefix_lens=extend_prefix_lens,
             extend_logprob_start_lens=extend_logprob_start_lens,
-            image_inputs=image_inputs,
-
+            image_inputs=[r.image_inputs for r in self.reqs],
+            encoder_cached=self.encoder_cached,
+            encoder_lens=self.encoder_lens,
+            encoder_lens_cpu=self.encoder_lens_cpu,
+            encoder_out_cache_loc=self.encoder_out_cache_loc,
+            lora_paths=[req.lora_path for req in self.reqs],
             sampling_info=self.sampling_info,
+            mrope_positions_delta=mrope_positions_delta,
         )

     def copy(self):
+        # Only contain fields that will be used by process_batch_result
         return ScheduleBatch(
             reqs=self.reqs,
+            model_config=self.model_config,
             forward_mode=self.forward_mode,
             out_cache_loc=self.out_cache_loc,
             return_logprob=self.return_logprob,
|
|
887
1045
|
)
|
888
1046
|
|
889
1047
|
|
890
|
-
@dataclass
|
1048
|
+
@dataclasses.dataclass
|
891
1049
|
class ModelWorkerBatch:
|
892
1050
|
# The batch id
|
893
1051
|
bid: int
|
@@ -902,11 +1060,18 @@
     # The indices of output tokens in the token_to_kv_pool
     out_cache_loc: torch.Tensor

+    # The sum of all sequence lengths
+    seq_lens_sum: int
+
+    # The memory pool operation records
+    req_to_token_pool_records: Optional[List[Tuple[Tuple, torch.Tensor]]]
+
     # For logprob
     return_logprob: bool
     top_logprobs_nums: Optional[List[int]]

     # For extend
+    extend_num_tokens: Optional[int]
     extend_seq_lens: Optional[List[int]]
     extend_prefix_lens: Optional[List[int]]
     extend_logprob_start_lens: Optional[List[int]]
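`req_to_token_pool_records` presumably carries the writes applied to the request-to-token mapping as (index, value) pairs so the overlapped TP worker can replay them on its own copy; `get_write_records()` in `get_model_worker_batch` is the producer and `ModelWorkerBatch.to()` below moves the recorded tensors between devices. A simplified, self-contained sketch of such a record-and-replay pool (class and method names here are illustrative, not the actual memory-pool API):

import torch

class RecordingPool:
    """Toy pool that records every write so it can be replayed elsewhere."""

    def __init__(self, size, width):
        self.req_to_token = torch.zeros((size, width), dtype=torch.int32)
        self.records = []

    def write(self, indices, values):
        self.req_to_token[indices] = values
        self.records.append((indices, values))

    def get_write_records(self):
        records, self.records = self.records, []
        return records

pool = RecordingPool(size=4, width=8)
pool.write((0, slice(0, 3)), torch.tensor([10, 11, 12], dtype=torch.int32))

# Replay the same writes into an independent copy of the mapping.
replica = torch.zeros((4, 8), dtype=torch.int32)
for indices, values in pool.get_write_records():
    replica[indices] = values
assert torch.equal(replica, pool.req_to_token)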
@@ -914,26 +1079,31 @@
     # For multimodal
     image_inputs: Optional[List[ImageInputs]]

+    # For encoder-decoder
+    encoder_cached: Optional[List[bool]]
+    encoder_lens: Optional[torch.Tensor]
+    encoder_lens_cpu: Optional[List[int]]
+    encoder_out_cache_loc: Optional[torch.Tensor]
+
     # For LoRA
     lora_paths: Optional[List[str]]

     # Sampling info
     sampling_info: SamplingBatchInfo

+    # For Qwen2-VL
+    mrope_positions_delta: List[List[int]]
+
     def copy(self):
-        return
-
-
-
-
-
-
-
-
-
-
-
-            image_inputs=self.image_inputs,
-            lora_paths=self.lora_paths,
-            sampling_info=self.sampling_info.copy(),
-        )
+        return dataclasses.replace(self, sampling_info=self.sampling_info.copy())
+
+    def to(self, device: str):
+        self.input_ids = self.input_ids.to(device, non_blocking=True)
+        self.req_pool_indices = self.req_pool_indices.to(device, non_blocking=True)
+        self.seq_lens = self.seq_lens.to(device, non_blocking=True)
+        self.out_cache_loc = self.out_cache_loc.to(device, non_blocking=True)
+        self.req_to_token_pool_records = [
+            (x, y.to(device, non_blocking=True))
+            for x, y in self.req_to_token_pool_records
+        ]
+        self.sampling_info.to(device)