sglang 0.3.3.post1__py3-none-any.whl → 0.3.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_latency.py +28 -10
- sglang/bench_server_latency.py +21 -10
- sglang/bench_serving.py +101 -7
- sglang/global_config.py +0 -1
- sglang/srt/layers/attention/__init__.py +27 -5
- sglang/srt/layers/attention/double_sparsity_backend.py +281 -0
- sglang/srt/layers/attention/flashinfer_backend.py +352 -83
- sglang/srt/layers/attention/triton_backend.py +6 -4
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +772 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +5 -3
- sglang/srt/layers/attention/triton_ops/prefill_attention.py +4 -2
- sglang/srt/layers/sampler.py +6 -2
- sglang/srt/managers/detokenizer_manager.py +31 -10
- sglang/srt/managers/io_struct.py +4 -0
- sglang/srt/managers/schedule_batch.py +120 -43
- sglang/srt/managers/schedule_policy.py +2 -1
- sglang/srt/managers/scheduler.py +202 -140
- sglang/srt/managers/tokenizer_manager.py +5 -1
- sglang/srt/managers/tp_worker.py +111 -1
- sglang/srt/mem_cache/chunk_cache.py +8 -4
- sglang/srt/mem_cache/memory_pool.py +77 -4
- sglang/srt/mem_cache/radix_cache.py +15 -7
- sglang/srt/model_executor/cuda_graph_runner.py +4 -4
- sglang/srt/model_executor/forward_batch_info.py +16 -21
- sglang/srt/model_executor/model_runner.py +60 -1
- sglang/srt/models/baichuan.py +2 -3
- sglang/srt/models/chatglm.py +5 -6
- sglang/srt/models/commandr.py +1 -2
- sglang/srt/models/dbrx.py +1 -2
- sglang/srt/models/deepseek.py +4 -5
- sglang/srt/models/deepseek_v2.py +5 -6
- sglang/srt/models/exaone.py +1 -2
- sglang/srt/models/gemma.py +2 -2
- sglang/srt/models/gemma2.py +5 -5
- sglang/srt/models/gpt_bigcode.py +5 -5
- sglang/srt/models/grok.py +1 -2
- sglang/srt/models/internlm2.py +1 -2
- sglang/srt/models/llama.py +1 -2
- sglang/srt/models/llama_classification.py +1 -2
- sglang/srt/models/llama_reward.py +2 -3
- sglang/srt/models/llava.py +4 -8
- sglang/srt/models/llavavid.py +1 -2
- sglang/srt/models/minicpm.py +1 -2
- sglang/srt/models/minicpm3.py +5 -6
- sglang/srt/models/mixtral.py +1 -2
- sglang/srt/models/mixtral_quant.py +1 -2
- sglang/srt/models/olmo.py +352 -0
- sglang/srt/models/olmoe.py +1 -2
- sglang/srt/models/qwen.py +1 -2
- sglang/srt/models/qwen2.py +1 -2
- sglang/srt/models/qwen2_moe.py +4 -5
- sglang/srt/models/stablelm.py +1 -2
- sglang/srt/models/torch_native_llama.py +1 -2
- sglang/srt/models/xverse.py +1 -2
- sglang/srt/models/xverse_moe.py +4 -5
- sglang/srt/models/yivl.py +1 -2
- sglang/srt/openai_api/adapter.py +92 -49
- sglang/srt/openai_api/protocol.py +10 -2
- sglang/srt/sampling/penaltylib/orchestrator.py +28 -9
- sglang/srt/sampling/sampling_batch_info.py +92 -58
- sglang/srt/sampling/sampling_params.py +2 -0
- sglang/srt/server.py +116 -17
- sglang/srt/server_args.py +121 -45
- sglang/srt/utils.py +11 -3
- sglang/test/few_shot_gsm8k.py +4 -1
- sglang/test/few_shot_gsm8k_engine.py +144 -0
- sglang/test/srt/sampling/penaltylib/utils.py +16 -12
- sglang/version.py +1 -1
- {sglang-0.3.3.post1.dist-info → sglang-0.3.4.dist-info}/METADATA +72 -29
- {sglang-0.3.3.post1.dist-info → sglang-0.3.4.dist-info}/RECORD +73 -70
- {sglang-0.3.3.post1.dist-info → sglang-0.3.4.dist-info}/WHEEL +1 -1
- sglang/srt/layers/attention/flashinfer_utils.py +0 -237
- {sglang-0.3.3.post1.dist-info → sglang-0.3.4.dist-info}/LICENSE +0 -0
- {sglang-0.3.3.post1.dist-info → sglang-0.3.4.dist-info}/top_level.txt +0 -0
@@ -26,7 +26,9 @@ from sglang.srt.layers.attention.triton_ops.prefill_attention import (
     context_attention_fwd,
 )
 
-CUDA_CAPABILITY = torch.cuda.get_device_capability()
+is_cuda_available = torch.cuda.is_available()
+if is_cuda_available:
+    CUDA_CAPABILITY = torch.cuda.get_device_capability()
 
 
 @triton.jit
@@ -286,12 +288,12 @@ def extend_attention_fwd(
     BLOCK_DPE = 0
     BLOCK_DV = triton.next_power_of_2(Lv)
 
-    if CUDA_CAPABILITY[0] >= 9:
+    if is_cuda_available and CUDA_CAPABILITY[0] >= 9:
         if Lq <= 256:
             BLOCK_M, BLOCK_N = (128, 64)
         else:
             BLOCK_M, BLOCK_N = (32, 64)
-    elif CUDA_CAPABILITY[0] >= 8:
+    elif is_cuda_available and CUDA_CAPABILITY[0] >= 8:
        if Lq <= 128:
             BLOCK_M, BLOCK_N = (128, 128)
         elif Lq <= 256:
sglang/srt/layers/attention/triton_ops/prefill_attention.py
CHANGED
@@ -24,7 +24,9 @@ import torch
 import triton
 import triton.language as tl
 
-CUDA_CAPABILITY = torch.cuda.get_device_capability()
+is_cuda_available = torch.cuda.is_available()
+if is_cuda_available:
+    CUDA_CAPABILITY = torch.cuda.get_device_capability()
 
 
 @triton.jit
@@ -145,7 +147,7 @@ def _fwd_kernel(
 
 
 def context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len):
-    if CUDA_CAPABILITY[0] >= 8:
+    if is_cuda_available and CUDA_CAPABILITY[0] >= 8:
         BLOCK = 128
     else:
         BLOCK = 64
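Both Triton kernel modules above now probe the GPU only when one is present, so importing them no longer fails on CPU-only installs. A minimal standalone sketch of the guard pattern (the block-size numbers are illustrative, not sglang's exact tuning):

import torch

# Only query the device when CUDA is actually available;
# torch.cuda.get_device_capability() raises on CPU-only installs.
is_cuda_available = torch.cuda.is_available()
if is_cuda_available:
    CUDA_CAPABILITY = torch.cuda.get_device_capability()


def pick_block_size(head_dim: int) -> int:
    # Illustrative block-size selection: larger tiles on compute capability >= 8,
    # a conservative fallback otherwise. Short-circuiting on is_cuda_available
    # keeps CUDA_CAPABILITY from being referenced when it was never defined.
    if is_cuda_available and CUDA_CAPABILITY[0] >= 8:
        return 128 if head_dim <= 128 else 64
    return 64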
sglang/srt/layers/sampler.py
CHANGED
@@ -21,6 +21,10 @@ logger = logging.getLogger(__name__)
 
 
 class Sampler(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.use_nan_detectioin = not global_server_args_dict["disable_nan_detection"]
+
     def forward(
         self,
         logits: Union[torch.Tensor, LogitsProcessorOutput],
@@ -36,13 +40,13 @@ class Sampler(nn.Module):
         logits = None
         del logits
 
-        if torch.any(torch.isnan(probs)):
+        if self.use_nan_detectioin and torch.any(torch.isnan(probs)):
             logger.warning("Detected errors during sampling! NaN in the probability.")
             probs = torch.where(
                 torch.isnan(probs), torch.full_like(probs, 1e-10), probs
             )
 
-        if sampling_info.
+        if sampling_info.is_all_greedy:
             # Use torch.argmax if all requests use greedy sampling
             batch_next_token_ids = torch.argmax(probs, -1)
         elif global_server_args_dict["sampling_backend"] == "flashinfer":
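The sampler now builds its NaN check from the new disable_nan_detection server argument and branches on a precomputed all-greedy flag instead of inspecting the sampling parameters at sample time. A rough, self-contained sketch of the same idea; SamplingInfo here is a stand-in, not sglang's SamplingBatchInfo:

from dataclasses import dataclass

import torch


@dataclass
class SamplingInfo:  # stand-in for sglang's SamplingBatchInfo
    is_all_greedy: bool


def sample(probs: torch.Tensor, info: SamplingInfo, nan_detection: bool = True) -> torch.Tensor:
    # Optionally scrub NaNs before sampling; skipping the check avoids a
    # host/device synchronization on every step.
    if nan_detection and torch.any(torch.isnan(probs)):
        probs = torch.where(torch.isnan(probs), torch.full_like(probs, 1e-10), probs)

    if info.is_all_greedy:
        # All requests are greedy: argmax is cheaper than multinomial sampling.
        return torch.argmax(probs, dim=-1)
    return torch.multinomial(probs, num_samples=1).squeeze(-1)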
sglang/srt/managers/detokenizer_manager.py
CHANGED
@@ -18,7 +18,7 @@ limitations under the License.
 import dataclasses
 import logging
 from collections import OrderedDict
-from typing import List
+from typing import List, Union
 
 import zmq
 
@@ -29,7 +29,7 @@ from sglang.srt.managers.io_struct import (
     BatchTokenIDOut,
     UpdateWeightReqOutput,
 )
-from sglang.srt.managers.schedule_batch import FINISH_MATCHED_STR
+from sglang.srt.managers.schedule_batch import FINISH_MATCHED_STR, FINISH_MATCHED_TOKEN
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import configure_logger, kill_parent_process
 from sglang.utils import find_printable_text, get_exception_traceback
@@ -75,6 +75,21 @@ class DetokenizerManager:
 
         self.decode_status = LimitedCapacityDict()
 
+    def trim_eos(self, output: Union[str, List[int]], finished_reason, no_stop_trim):
+        if no_stop_trim:
+            return output
+
+        # Trim stop str. TODO(lmzheng): handle the case where multiple stop strs are hit
+        if isinstance(finished_reason, FINISH_MATCHED_STR) and isinstance(output, str):
+            pos = output.find(finished_reason.matched)
+            return output[:pos] if pos != -1 else output
+        if isinstance(finished_reason, FINISH_MATCHED_TOKEN) and isinstance(
+            output, list
+        ):
+            assert len(output) > 0
+            return output[:-1]
+        return output
+
     def event_loop(self):
         """The event loop that handles requests"""
 
@@ -122,7 +137,13 @@ class DetokenizerManager:
                 s = self.decode_status[rid]
                 s.decode_ids = recv_obj.decode_ids[i]
 
-                read_ids.append(
+                read_ids.append(
+                    self.trim_eos(
+                        s.decode_ids[s.surr_offset :],
+                        recv_obj.finished_reason[i],
+                        recv_obj.no_stop_trim[i],
+                    )
+                )
                 surr_ids.append(s.decode_ids[s.surr_offset : s.read_offset])
 
             # TODO(lmzheng): handle skip_special_tokens/spaces_between_special_tokens per request
@@ -152,13 +173,13 @@ class DetokenizerManager:
                 else:
                     new_text = find_printable_text(new_text)
 
-                output_strs.append(
-
-
-
-
-
-
+                output_strs.append(
+                    self.trim_eos(
+                        s.decoded_text + new_text,
+                        recv_obj.finished_reason[i],
+                        recv_obj.no_stop_trim[i],
+                    )
+                )
 
             self.send_to_tokenizer.send_pyobj(
                 BatchStrOut(
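The new trim_eos helper strips the matched stop string (or the final stop token) from streamed output unless the request set no_stop_trim. A simplified standalone version of that logic, using plain arguments in place of sglang's FINISH_MATCHED_STR / FINISH_MATCHED_TOKEN objects:

from typing import List, Optional, Union


def trim_stop(
    output: Union[str, List[int]],
    matched_stop_str: Optional[str],
    stopped_on_token: bool,
    no_stop_trim: bool,
) -> Union[str, List[int]]:
    """Drop the stop sequence from decoded output, mirroring the diff's trim_eos."""
    if no_stop_trim:
        return output

    # Text output that finished on a stop string: cut at its first occurrence.
    if isinstance(output, str) and matched_stop_str is not None:
        pos = output.find(matched_stop_str)
        return output[:pos] if pos != -1 else output

    # Token-id output that finished on a stop token: drop the last id.
    if isinstance(output, list) and stopped_on_token:
        return output[:-1]

    return output


# Example usage:
print(trim_stop("Answer: 42</s>extra", "</s>", False, False))  # "Answer: 42"
print(trim_stop([11, 22, 2], None, True, False))               # [11, 22]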
sglang/srt/managers/io_struct.py
CHANGED
@@ -56,6 +56,9 @@ class GenerateReqInput:
     # LoRA related
     lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
 
+    # Whether it is a single request or a batch request
+    is_single: bool = True
+
     def post_init(self):
         if (self.text is None and self.input_ids is None) or (
             self.text is not None and self.input_ids is not None
@@ -295,6 +298,7 @@ class BatchTokenIDOut:
     spaces_between_special_tokens: List[bool]
     meta_info: List[Dict]
     finished_reason: List[BaseFinishReason]
+    no_stop_trim: List[bool]
 
 
 @dataclass
sglang/srt/managers/schedule_batch.py
CHANGED
@@ -53,6 +53,7 @@ global_server_args_dict = {
     "triton_attention_reduce_in_fp32": ServerArgs.triton_attention_reduce_in_fp32,
     "disable_mla": ServerArgs.disable_mla,
     "torchao_config": ServerArgs.torchao_config,
+    "disable_nan_detection": ServerArgs.disable_nan_detection,
 }
 
 
@@ -196,6 +197,9 @@ class Req:
         # this does not include the jump forward tokens.
         self.completion_tokens_wo_jump_forward = 0
 
+        # The number of cached tokens, that were already cached in the KV store
+        self.cached_tokens = 0
+
         # For vision inputs
         self.image_inputs: Optional[ImageInputs] = None
 
@@ -203,6 +207,7 @@ class Req:
         self.prefix_indices = []
         self.extend_input_len = 0
         self.last_node = None
+        self.is_inflight_req = 0
 
         # Logprobs (arguments)
         self.return_logprob = False
@@ -391,25 +396,30 @@ class Req:
         return f"rid(n={self.rid}, " f"input_ids={self.origin_input_ids}, "
 
 
+bid = 0
+
+
 @dataclass
 class ScheduleBatch:
     """Store all inforamtion of a batch."""
 
     # Request, memory pool, and cache
     reqs: List[Req]
-    req_to_token_pool: ReqToTokenPool
-    token_to_kv_pool: BaseTokenToKVPool
-    tree_cache: BasePrefixCache
+    req_to_token_pool: ReqToTokenPool = None
+    token_to_kv_pool: BaseTokenToKVPool = None
+    tree_cache: BasePrefixCache = None
 
     forward_mode: ForwardMode = None
     sampling_info: SamplingBatchInfo = None
 
     # Batched arguments to model runner
-    input_ids:
-    req_pool_indices:
-    seq_lens:
+    input_ids: torch.Tensor = None
+    req_pool_indices: torch.Tensor = None
+    seq_lens: torch.Tensor = None
     out_cache_loc: torch.Tensor = None
 
+    output_ids: torch.Tensor = None
+
     # For processing logprobs
     return_logprob: bool = False
     top_logprobs_nums: Optional[List[int]] = None
@@ -419,6 +429,7 @@ class ScheduleBatch:
     extend_lens: List[int] = None
     extend_num_tokens: int = None
     running_bs: int = None
+    decoding_reqs: List[Req] = None
 
     # Stream
     has_stream: bool = False
@@ -492,17 +503,24 @@ class ScheduleBatch:
 
         pt = 0
         for i, req in enumerate(reqs):
+            already_computed = (
+                req.extend_logprob_start_len + 1 + req.cached_tokens
+                if req.extend_logprob_start_len > 0
+                else 0
+            )
+            req.cached_tokens += len(req.prefix_indices) - already_computed
+
             req.req_pool_idx = req_pool_indices[i]
             pre_len, seq_len = len(req.prefix_indices), len(req.fill_ids)
             seq_lens.append(seq_len)
             assert seq_len - pre_len == req.extend_input_len
 
             if pre_len > 0:
-                self.req_to_token_pool.req_to_token[req.req_pool_idx]
-
-
+                self.req_to_token_pool.req_to_token[req.req_pool_idx, :pre_len] = (
+                    req.prefix_indices
+                )
 
-            self.req_to_token_pool.req_to_token[req.req_pool_idx
+            self.req_to_token_pool.req_to_token[req.req_pool_idx, pre_len:seq_len] = (
                 out_cache_loc[pt : pt + req.extend_input_len]
             )
 
@@ -518,10 +536,15 @@ class ScheduleBatch:
             pt += req.extend_input_len
 
         # Set fields
-
-        self.
-
-
+        self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32).to(
+            self.device, non_blocking=True
+        )
+        self.req_pool_indices = torch.tensor(req_pool_indices, dtype=torch.int32).to(
+            self.device, non_blocking=True
+        )
+        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32).to(
+            self.device, non_blocking=True
+        )
 
         self.extend_num_tokens = extend_num_tokens
         self.out_cache_loc = out_cache_loc
@@ -531,7 +554,9 @@ class ScheduleBatch:
         self.extend_lens = [r.extend_input_len for r in reqs]
         self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]
 
-        self.sampling_info = SamplingBatchInfo.from_schedule_batch(
+        self.sampling_info = SamplingBatchInfo.from_schedule_batch(
+            self, vocab_size, global_server_args_dict["disable_penalizer"]
+        )
 
     def mix_with_running(self, running_batch: "ScheduleBatch"):
         self.forward_mode = ForwardMode.MIXED
@@ -586,9 +611,11 @@ class ScheduleBatch:
 
         retracted_reqs = []
         seq_lens_cpu = self.seq_lens.cpu().numpy()
+        first_iter = True
         while (
             self.token_to_kv_pool.available_size()
             < len(sorted_indices) * global_config.retract_decode_steps
+            or first_iter
         ):
             if len(sorted_indices) == 1:
                 # Corner case: only one request left
@@ -597,6 +624,7 @@ class ScheduleBatch:
                 ), "No space left for only one request"
                 break
 
+            first_iter = False
             idx = sorted_indices.pop()
             req = self.reqs[idx]
             retracted_reqs.append(req)
@@ -637,7 +665,7 @@ class ScheduleBatch:
             req.last_update_decode_tokens = 0
             req.logprob_start_len = 10**9
 
-        self.filter_batch(sorted_indices)
+        self.filter_batch(keep_indices=sorted_indices)
 
         # Reqs in batch are filtered
         total_decoded_tokens = sum(len(r.output_ids) for r in self.reqs)
@@ -652,7 +680,7 @@ class ScheduleBatch:
 
     def check_for_jump_forward(self, pad_input_ids_func):
         jump_forward_reqs = []
-
+        keep_indices = set(i for i in range(len(self.reqs)))
 
         for i, req in enumerate(self.reqs):
             if req.jump_forward_map is not None:
@@ -712,63 +740,71 @@ class ScheduleBatch:
                 )
 
                 jump_forward_reqs.append(req)
-
+                keep_indices.remove(i)
 
-        self.filter_batch(
+        self.filter_batch(keep_indices=list(keep_indices))
 
         return jump_forward_reqs
 
-    def prepare_for_decode(self
+    def prepare_for_decode(self):
         self.forward_mode = ForwardMode.DECODE
 
-
-
-
-
-
-
-        self.input_ids = torch.tensor(
-            input_ids, dtype=torch.int32, device=self.seq_lens.device
-        )
-        self.seq_lens.add_(1)
+        self.input_ids = self.output_ids
+        self.output_ids = None
+        if self.sampling_info.penalizer_orchestrator:
+            self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
+                self.input_ids
+            )
 
         # Alloc mem
         bs = len(self.reqs)
         self.out_cache_loc = self.alloc_token_slots(bs)
 
-        self.req_to_token_pool.req_to_token[
-            self.
-
+        self.req_to_token_pool.req_to_token[self.req_pool_indices, self.seq_lens] = (
+            self.out_cache_loc
+        )
+        self.seq_lens.add_(1)
 
-    def filter_batch(
-
+    def filter_batch(
+        self,
+        current_inflight_req: Optional[Req] = None,
+        keep_indices: Optional[List[int]] = None,
+    ):
+        if keep_indices is None:
+            keep_indices = [
+                i
+                for i in range(len(self.reqs))
+                if not self.reqs[i].finished()
+                and self.reqs[i] is not current_inflight_req
+            ]
+
+        if keep_indices is None or len(keep_indices) == 0:
             # Filter out all requests
             self.reqs = []
             return
 
-        if len(
+        if len(keep_indices) == len(self.reqs):
             # No need to filter
             return
 
-        self.reqs = [self.reqs[i] for i in
-        new_indices = torch.tensor(
-
+        self.reqs = [self.reqs[i] for i in keep_indices]
+        new_indices = torch.tensor(keep_indices, dtype=torch.int32).to(
+            self.device, non_blocking=True
        )
         self.req_pool_indices = self.req_pool_indices[new_indices]
         self.seq_lens = self.seq_lens[new_indices]
         self.out_cache_loc = None
+        self.output_ids = self.output_ids[new_indices]
         self.return_logprob = any(req.return_logprob for req in self.reqs)
         if self.return_logprob:
-            self.top_logprobs_nums = [
-                self.top_logprobs_nums[i] for i in unfinished_indices
-            ]
+            self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in keep_indices]
         else:
             self.top_logprobs_nums = None
 
         self.has_stream = any(req.stream for req in self.reqs)
         self.has_regex = any(req.regex_fsm for req in self.reqs)
 
-        self.sampling_info.filter_batch(
+        self.sampling_info.filter_batch(keep_indices, new_indices)
 
     def merge_batch(self, other: "ScheduleBatch"):
         # Penalizer orchestrator must be merged before Batch.reqs is merged. This is because
@@ -781,6 +817,8 @@ class ScheduleBatch:
         )
         self.seq_lens = torch.concat([self.seq_lens, other.seq_lens])
         self.out_cache_loc = None
+        if self.output_ids is not None:
+            self.output_ids = torch.concat([self.output_ids, other.output_ids])
         if self.return_logprob and other.return_logprob:
             self.top_logprobs_nums.extend(other.top_logprobs_nums)
         elif self.return_logprob:
@@ -813,7 +851,11 @@ class ScheduleBatch:
         else:
             self.sampling_info.regex_fsms = None
 
+        global bid
+        bid += 1
+
         return ModelWorkerBatch(
+            bid=bid,
             forward_mode=self.forward_mode,
             input_ids=self.input_ids,
             req_pool_indices=self.req_pool_indices,
@@ -829,9 +871,26 @@ class ScheduleBatch:
             sampling_info=self.sampling_info,
         )
 
+    def copy(self):
+        return ScheduleBatch(
+            reqs=self.reqs,
+            forward_mode=self.forward_mode,
+            out_cache_loc=self.out_cache_loc,
+            return_logprob=self.return_logprob,
+            decoding_reqs=self.decoding_reqs,
+        )
+
+    def __str__(self):
+        return (
+            f"ScheduleBatch(forward_mode={self.forward_mode.name}, "
+            f"#req={(len(self.reqs))})"
+        )
+
 
 @dataclass
 class ModelWorkerBatch:
+    # The batch id
+    bid: int
     # The forward mode
     forward_mode: ForwardMode
     # The input ids
@@ -860,3 +919,21 @@ class ModelWorkerBatch:
 
     # Sampling info
     sampling_info: SamplingBatchInfo
+
+    def copy(self):
+        return ModelWorkerBatch(
+            bid=self.bid,
+            forward_mode=self.forward_mode,
+            input_ids=self.input_ids.clone(),
+            req_pool_indices=self.req_pool_indices,
+            seq_lens=self.seq_lens.clone(),
+            out_cache_loc=self.out_cache_loc,
+            return_logprob=self.return_logprob,
+            top_logprobs_nums=self.top_logprobs_nums,
+            extend_seq_lens=self.extend_seq_lens,
+            extend_prefix_lens=self.extend_prefix_lens,
+            extend_logprob_start_lens=self.extend_logprob_start_lens,
+            image_inputs=self.image_inputs,
+            lora_paths=self.lora_paths,
+            sampling_info=self.sampling_info.copy(),
+        )
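filter_batch now derives keep_indices itself (dropping finished requests and the in-flight request) and uses a single index tensor to slice every per-request field, including the new output_ids. A toy sketch of that indexing pattern; Req and MiniBatch below are illustrative stand-ins, not sglang's classes:

from dataclasses import dataclass
from typing import List, Optional

import torch


@dataclass
class Req:  # illustrative stand-in
    rid: int
    finished: bool = False


@dataclass
class MiniBatch:  # illustrative stand-in for a schedule batch
    reqs: List[Req]
    seq_lens: torch.Tensor
    output_ids: torch.Tensor

    def filter_batch(
        self,
        current_inflight_req: Optional[Req] = None,
        keep_indices: Optional[List[int]] = None,
    ):
        if keep_indices is None:
            # Keep unfinished requests, excluding the in-flight one.
            keep_indices = [
                i
                for i, r in enumerate(self.reqs)
                if not r.finished and r is not current_inflight_req
            ]
        if len(keep_indices) == len(self.reqs):
            return  # nothing to drop
        idx = torch.tensor(keep_indices, dtype=torch.long)
        self.reqs = [self.reqs[i] for i in keep_indices]
        self.seq_lens = self.seq_lens[idx]
        self.output_ids = self.output_ids[idx]


batch = MiniBatch(
    reqs=[Req(0), Req(1, finished=True), Req(2)],
    seq_lens=torch.tensor([5, 7, 3]),
    output_ids=torch.tensor([101, 102, 103]),
)
batch.filter_batch()
print([r.rid for r in batch.reqs], batch.seq_lens.tolist())  # [0, 2] [5, 3]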
sglang/srt/managers/schedule_policy.py
CHANGED
@@ -45,12 +45,13 @@ class SchedulePolicy:
     def calc_priority(self, waiting_queue: List[Req]):
         # Compute matched prefix length
         prefix_computed = False
-        if self.policy
+        if self.policy == "lpm" or self.policy == "dfs-weight":
             for r in waiting_queue:
                 # NOTE: the prefix_indices must always be aligned with last_node
                 r.prefix_indices, r.last_node = self.tree_cache.match_prefix(
                     rid=r.rid, key=r.adjust_max_prefix_ids()
                 )
+
             prefix_computed = True
 
         if self.policy == "lpm":
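calc_priority now walks the prefix cache only for the cache-aware policies ("lpm" and "dfs-weight"); arrival-order policies skip the matching pass. A rough sketch of that gating, using an assumed match_prefix callback rather than sglang's tree-cache API:

from typing import Callable, List, Tuple


def compute_priorities(
    policy: str,
    waiting_prompts: List[List[int]],
    match_prefix: Callable[[List[int]], int],
) -> List[Tuple[int, int]]:
    """Return (queue index, matched prefix length) pairs in scheduling order."""
    # Only cache-aware policies pay for the prefix-matching walk.
    cache_aware = policy in ("lpm", "dfs-weight")
    prefix_lens = [
        match_prefix(tokens) if cache_aware else 0 for tokens in waiting_prompts
    ]
    if policy == "lpm":
        # Longest-prefix-match first: a larger cached prefix runs earlier.
        order = sorted(range(len(waiting_prompts)), key=lambda i: -prefix_lens[i])
    else:
        order = list(range(len(waiting_prompts)))  # FCFS-style: keep arrival order
    return [(i, prefix_lens[i]) for i in order]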
|