sglang 0.3.4__py3-none-any.whl → 0.3.4.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_latency.py +2 -1
- sglang/lang/chat_template.py +17 -0
- sglang/launch_server_llavavid.py +1 -1
- sglang/srt/configs/__init__.py +3 -0
- sglang/srt/configs/model_config.py +27 -2
- sglang/srt/configs/qwen2vl.py +133 -0
- sglang/srt/constrained/fsm_cache.py +10 -3
- sglang/srt/conversation.py +27 -0
- sglang/srt/hf_transformers_utils.py +16 -1
- sglang/srt/layers/attention/__init__.py +16 -5
- sglang/srt/layers/attention/double_sparsity_backend.py +22 -6
- sglang/srt/layers/attention/flashinfer_backend.py +174 -54
- sglang/srt/layers/attention/triton_backend.py +22 -6
- sglang/srt/layers/attention/triton_ops/prefill_attention.py +26 -4
- sglang/srt/layers/linear.py +89 -63
- sglang/srt/layers/logits_processor.py +5 -5
- sglang/srt/layers/rotary_embedding.py +112 -0
- sglang/srt/layers/sampler.py +51 -39
- sglang/srt/lora/lora.py +3 -1
- sglang/srt/managers/data_parallel_controller.py +1 -1
- sglang/srt/managers/detokenizer_manager.py +4 -0
- sglang/srt/managers/image_processor.py +186 -13
- sglang/srt/managers/io_struct.py +10 -0
- sglang/srt/managers/schedule_batch.py +238 -68
- sglang/srt/managers/scheduler.py +69 -50
- sglang/srt/managers/tokenizer_manager.py +24 -4
- sglang/srt/managers/tp_worker.py +26 -111
- sglang/srt/managers/tp_worker_overlap_thread.py +209 -0
- sglang/srt/mem_cache/memory_pool.py +56 -10
- sglang/srt/mem_cache/radix_cache.py +4 -3
- sglang/srt/model_executor/cuda_graph_runner.py +87 -28
- sglang/srt/model_executor/forward_batch_info.py +83 -3
- sglang/srt/model_executor/model_runner.py +32 -11
- sglang/srt/models/chatglm.py +3 -3
- sglang/srt/models/deepseek_v2.py +2 -2
- sglang/srt/models/mllama.py +1004 -0
- sglang/srt/models/qwen2_vl.py +724 -0
- sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +6 -3
- sglang/srt/sampling/sampling_batch_info.py +13 -3
- sglang/srt/sampling/sampling_params.py +5 -7
- sglang/srt/server.py +12 -0
- sglang/srt/server_args.py +10 -0
- sglang/srt/utils.py +22 -0
- sglang/test/run_eval.py +2 -0
- sglang/test/runners.py +20 -1
- sglang/test/srt/sampling/penaltylib/utils.py +1 -0
- sglang/test/test_utils.py +100 -3
- sglang/version.py +1 -1
- {sglang-0.3.4.dist-info → sglang-0.3.4.post2.dist-info}/METADATA +17 -18
- {sglang-0.3.4.dist-info → sglang-0.3.4.post2.dist-info}/RECORD +53 -48
- {sglang-0.3.4.dist-info → sglang-0.3.4.post2.dist-info}/LICENSE +0 -0
- {sglang-0.3.4.dist-info → sglang-0.3.4.post2.dist-info}/WHEEL +0 -0
- {sglang-0.3.4.dist-info → sglang-0.3.4.post2.dist-info}/top_level.txt +0 -0
sglang/srt/managers/scheduler.py
CHANGED
@@ -38,6 +38,8 @@ from sglang.srt.managers.io_struct import (
     BatchEmbeddingOut,
     BatchTokenIDOut,
     FlushCacheReq,
+    GetMemPoolSizeReq,
+    GetMemPoolSizeReqOutput,
     ProfileReq,
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
@@ -51,6 +53,7 @@ from sglang.srt.managers.schedule_batch import (
     ImageInputs,
     Req,
     ScheduleBatch,
+    global_server_args_dict,
 )
 from sglang.srt.managers.schedule_policy import (
     AddReqResult,
@@ -58,6 +61,7 @@ from sglang.srt.managers.schedule_policy import (
     SchedulePolicy,
 )
 from sglang.srt.managers.tp_worker import TpModelWorker
+from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.radix_cache import RadixCache
 from sglang.srt.server_args import PortArgs, ServerArgs
@@ -67,7 +71,6 @@ from sglang.srt.utils import (
     is_generation_model,
     is_multimodal_model,
     kill_parent_process,
-    pytorch_profile,
     set_random_seed,
     suppress_other_loggers,
 )
@@ -91,6 +94,7 @@ class Scheduler:
         port_args: PortArgs,
         gpu_id: int,
         tp_rank: int,
+        dp_rank: Optional[int],
     ):
         # Parse args
         self.server_args = server_args
@@ -100,6 +104,7 @@ class Scheduler:
         self.disable_regex_jump_forward = server_args.disable_regex_jump_forward
         self.lora_paths = server_args.lora_paths
         self.max_loras_per_batch = server_args.max_loras_per_batch
+        self.enable_overlap = server_args.enable_overlap_schedule

         # Init inter-process communication
         context = zmq.Context(2)
@@ -143,27 +148,37 @@ class Scheduler:
         )

         # Launch a tensor parallel worker
-        self.tp_worker = TpModelWorker(
+        if self.enable_overlap:
+            TpWorkerClass = TpModelWorkerClient
+        else:
+            TpWorkerClass = TpModelWorker
+
+        self.tp_worker = TpWorkerClass(
+            server_args=server_args,
             gpu_id=gpu_id,
             tp_rank=tp_rank,
-            server_args=server_args,
+            dp_rank=dp_rank,
             nccl_port=port_args.nccl_port,
         )
-        self.tp_cpu_group = self.tp_worker.model_runner.tp_group.cpu_group
-        self.device = self.tp_worker.device

         # Get token and memory info from the model worker
         (
             self.max_total_num_tokens,
             self.max_prefill_tokens,
             self.max_running_requests,
+            self.max_req_len,
             self.max_req_input_len,
             self.random_seed,
-        ) = self.tp_worker.get_token_and_memory_info()
+            self.device,
+            worker_global_server_args_dict,
+            _,
+            _,
+            _,
+        ) = self.tp_worker.get_worker_info()
+        self.tp_cpu_group = self.tp_worker.get_tp_cpu_group()
+        self.pad_input_ids_func = self.tp_worker.get_pad_input_ids_func()
+        global_server_args_dict.update(worker_global_server_args_dict)
         set_random_seed(self.random_seed)
-        self.pad_input_ids_func = getattr(
-            self.tp_worker.model_runner.model, "pad_input_ids", None
-        )

         # Print debug info
         logger.info(
@@ -173,9 +188,8 @@ class Scheduler:
             f"context_len={self.model_config.context_len}"
         )

-        # Init cache
-        self.req_to_token_pool = self.tp_worker.model_runner.req_to_token_pool
-        self.token_to_kv_pool = self.tp_worker.model_runner.token_to_kv_pool
+        # Init memory pool and cache
+        self.req_to_token_pool, self.token_to_kv_pool = self.tp_worker.get_memory_pool()

         if (
             server_args.chunked_prefill_size is not None
@@ -253,22 +267,9 @@ class Scheduler:
             with_stack=True,
         )

-        # Init states for overlap schedule
-        if self.server_args.enable_overlap_schedule:
-            self.forward_batch_generation = (
-                self.tp_worker.forward_batch_generation_non_blocking
-            )
-            self.resolve_next_token_ids = (
-                lambda bid, x: self.tp_worker.resolve_future_token_ids(bid)
-            )
-            self.cache_finished_req = self.tree_cache.cache_finished_req
-        else:
-            self.forward_batch_generation = self.tp_worker.forward_batch_generation
-            self.resolve_next_token_ids = lambda bid, x: x.tolist()
-            self.cache_finished_req = self.tree_cache.cache_finished_req
-
     @torch.inference_mode()
     def event_loop_normal(self):
+        """A normal blocking scheduler loop."""
         self.last_batch = None

         while True:
@@ -299,6 +300,7 @@ class Scheduler:

     @torch.inference_mode()
     def event_loop_overlap(self):
+        """A scheduler loop that overlaps the CPU processing and GPU computation."""
         result_queue = deque()

         self.last_batch = None
@@ -362,6 +364,10 @@ class Scheduler:
                 self.start_profile()
             else:
                 self.stop_profile()
+        elif isinstance(recv_req, GetMemPoolSizeReq):
+            self.send_to_detokenizer.send_pyobj(
+                GetMemPoolSizeReqOutput(self.max_total_num_tokens)
+            )
         else:
             raise ValueError(f"Invalid request: {recv_req}")

@@ -415,19 +421,20 @@ class Scheduler:
         )

         # Truncate prompts that are too long
-        if len(req.origin_input_ids) >= self.max_req_input_len:
+        if len(req.origin_input_ids) > self.max_req_input_len:
             logger.warning(
                 "Request length is longer than the KV cache pool size or "
                 "the max context length. Truncated!!!"
             )
             req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len]
+
         req.sampling_params.max_new_tokens = min(
             (
                 req.sampling_params.max_new_tokens
                 if req.sampling_params.max_new_tokens is not None
                 else 1 << 30
             ),
-            self.max_req_input_len - 1 - len(req.origin_input_ids),
+            self.max_req_len - len(req.origin_input_ids) - 1,
         )

         self.waiting_queue.append(req)
@@ -575,6 +582,7 @@ class Scheduler:
             else set([])
         )

+        # Get requests from the waiting queue to a new prefill batch
         for req in self.waiting_queue:
             if (
                 self.lora_paths
@@ -661,12 +669,13 @@ class Scheduler:
             self.req_to_token_pool,
             self.token_to_kv_pool,
             self.tree_cache,
+            self.model_config,
         )
-        new_batch.prepare_for_extend(self.model_config.vocab_size)
+        new_batch.prepare_for_extend()

         # Mixed-style chunked prefill
         if self.is_mixed_chunk and self.running_batch is not None:
-            self.running_batch.prepare_for_decode()
+            self.running_batch.prepare_for_decode(self.enable_overlap)
             new_batch.mix_with_running(self.running_batch)
             new_batch.decoding_reqs = self.running_batch.reqs
             self.running_batch = None
@@ -676,6 +685,7 @@ class Scheduler:
         return new_batch

     def update_running_batch(self):
+        """Update the current running decoding batch."""
         global test_retract
         batch = self.running_batch

@@ -712,13 +722,14 @@ class Scheduler:
             return

         # Update batch tensors
-        batch.prepare_for_decode()
+        batch.prepare_for_decode(self.enable_overlap)

     def run_batch(self, batch: ScheduleBatch):
+        """Run a batch."""
         if self.is_generation:
             if batch.forward_mode.is_decode() or batch.extend_num_tokens != 0:
                 model_worker_batch = batch.get_model_worker_batch()
-                logits_output, next_token_ids = self.forward_batch_generation(
+                logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
                     model_worker_batch
                 )
             else:
@@ -749,9 +760,12 @@ class Scheduler:
     def process_batch_result_prefill(self, batch: ScheduleBatch, result):
         if self.is_generation:
             logits_output, next_token_ids, bid = result
-
-
-
+
+            if self.enable_overlap:
+                logits_output, next_token_ids = self.tp_worker.resulve_batch_result(bid)
+            else:
+                # Move next_token_ids and logprobs to cpu
+                if batch.return_logprob:
                     logits_output.next_token_logprobs = (
                         logits_output.next_token_logprobs[
                             torch.arange(len(next_token_ids), device=self.device),
@@ -764,8 +778,7 @@ class Scheduler:
                     logits_output.normalized_prompt_logprobs = (
                         logits_output.normalized_prompt_logprobs.tolist()
                     )
-
-            next_token_ids = self.resolve_next_token_ids(bid, next_token_ids)
+                next_token_ids = next_token_ids.tolist()

             # Check finish conditions
             logprob_pt = 0
@@ -779,7 +792,7 @@ class Scheduler:
                 req.check_finished()

                 if req.finished():
-                    self.cache_finished_req(req)
+                    self.tree_cache.cache_finished_req(req)
                 elif not batch.decoding_reqs or req not in batch.decoding_reqs:
                     self.tree_cache.cache_unfinished_req(req)

@@ -808,7 +821,7 @@ class Scheduler:
                 req.check_finished()

                 if req.finished():
-                    self.cache_finished_req(req)
+                    self.tree_cache.cache_finished_req(req)
                 else:
                     self.tree_cache.cache_unfinished_req(req)

@@ -818,14 +831,17 @@ class Scheduler:
         logits_output, next_token_ids, bid = result
         self.num_generated_tokens += len(batch.reqs)

-
-
-            next_token_logprobs = logits_output.next_token_logprobs
-
-
-
-
-
+        if self.enable_overlap:
+            logits_output, next_token_ids = self.tp_worker.resulve_batch_result(bid)
+            next_token_logprobs = logits_output.next_token_logprobs
+        else:
+            # Move next_token_ids and logprobs to cpu
+            if batch.return_logprob:
+                next_token_logprobs = logits_output.next_token_logprobs[
+                    torch.arange(len(next_token_ids), device=self.device),
+                    next_token_ids,
+                ].tolist()
+            next_token_ids = next_token_ids.tolist()

         self.token_to_kv_pool.free_group_begin()

@@ -845,7 +861,7 @@ class Scheduler:
                 )

                 if req.finished():
-                    self.cache_finished_req(req)
+                    self.tree_cache.cache_finished_req(req)

                 if req.return_logprob:
                     req.output_token_logprobs.append(
@@ -936,6 +952,7 @@ class Scheduler:
         return num_input_logprobs

     def stream_output(self, reqs: List[Req]):
+        """Stream the output to detokenizer."""
         output_rids = []
         output_meta_info = []
         output_finished_reason: List[BaseFinishReason] = []
@@ -1033,6 +1050,7 @@ class Scheduler:
         )

     def flush_cache(self):
+        """Flush the memory pool and cache."""
         if len(self.waiting_queue) == 0 and (
             self.running_batch is None or len(self.running_batch.reqs) == 0
         ):
@@ -1069,10 +1087,11 @@ class Scheduler:
            for req in self.running_batch.reqs:
                if req.rid == recv_req.rid and not req.finished():
                    req.finished_reason = FINISH_ABORT()
-                    self.cache_finished_req(req)
+                    self.tree_cache.cache_finished_req(req)
                    break

     def update_weights(self, recv_req: UpdateWeightReqInput):
+        """In-place update of the weights."""
         success, message = self.tp_worker.update_weights(recv_req)
         if success:
             flash_cache_success = self.flush_cache()
@@ -1112,7 +1131,7 @@ def run_scheduler_process(
     suppress_other_loggers()

     try:
-        scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank)
+        scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank)
         pipe_writer.send("ready")
         if server_args.enable_overlap_schedule:
             scheduler.event_loop_overlap()
sglang/srt/managers/tokenizer_manager.py
CHANGED
@@ -46,6 +46,8 @@ from sglang.srt.managers.io_struct import (
     EmbeddingReqInput,
     FlushCacheReq,
     GenerateReqInput,
+    GetMemPoolSizeReq,
+    GetMemPoolSizeReqOutput,
     ProfileReq,
     RewardReqInput,
     TokenizedEmbeddingReqInput,
@@ -122,7 +124,7 @@ class TokenizerManager:

                 # We want to parallelize the image pre-processing so we create an executor for it
                 self.image_processor = get_image_processor(
-                    self.hf_config, server_args, self.processor
+                    self.hf_config, server_args, self.processor
                 )
             else:
                 self.tokenizer = get_tokenizer(
@@ -191,8 +193,10 @@ class TokenizerManager:
             sampling_params = self._get_sampling_params(obj.sampling_params)
             if self.is_generation:
                 image_inputs = await self.image_processor.process_images_async(
-                    obj.image_data, obj
+                    obj.image_data, input_text or input_ids, obj
                 )
+                if image_inputs and "input_ids" in image_inputs:
+                    input_ids = image_inputs["input_ids"]
                 return_logprob = obj.return_logprob
                 logprob_start_len = obj.logprob_start_len
                 top_logprobs_num = obj.top_logprobs_num
@@ -217,8 +221,10 @@ class TokenizerManager:
             sampling_params = self._get_sampling_params(obj.sampling_params[index])
             if self.is_generation:
                 image_inputs = await self.image_processor.process_images_async(
-                    obj.image_data[index], obj
+                    obj.image_data[index], input_text or input_ids, obj
                 )
+                if image_inputs and "input_ids" in image_inputs:
+                    input_ids = image_inputs["input_ids"]
                 return_logprob = obj.return_logprob[index]
                 logprob_start_len = obj.logprob_start_len[index]
                 top_logprobs_num = obj.top_logprobs_num[index]
@@ -263,8 +269,10 @@ class TokenizerManager:
             sampling_params = SamplingParams(**obj.sampling_params[0])
             sampling_params.max_new_tokens = 0
             image_inputs = await self.image_processor.process_images_async(
-                obj.image_data[0], obj
+                obj.image_data[0], input_text or input_ids, obj
             )
+            if image_inputs and "input_ids" in image_inputs:
+                input_ids = image_inputs["input_ids"]
             return_logprob = obj.return_logprob[0]
             logprob_start_len = obj.logprob_start_len[0]
             top_logprobs_num = obj.top_logprobs_num[0]
@@ -525,6 +533,15 @@ class TokenizerManager:
         req = ProfileReq.STOP_PROFILE
         self.send_to_scheduler.send_pyobj(req)

+    async def get_memory_pool_size(self):
+        if self.to_create_loop:
+            self.create_handle_loop()
+
+        req = GetMemPoolSizeReq()
+        self.send_to_scheduler.send_pyobj(req)
+        self.mem_pool_size = asyncio.Future()
+        return await self.mem_pool_size
+
     async def update_weights(
         self, obj: UpdateWeightReqInput, request: Optional[fastapi.Request] = None
     ):
@@ -584,6 +601,9 @@ class TokenizerManager:
             if isinstance(recv_obj, UpdateWeightReqOutput):
                 self.model_update_result.set_result(recv_obj)
                 continue
+            elif isinstance(recv_obj, GetMemPoolSizeReqOutput):
+                self.mem_pool_size.set_result(recv_obj)
+                continue

             assert isinstance(
                 recv_obj, (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut)
sglang/srt/managers/tp_worker.py
CHANGED
@@ -17,16 +17,12 @@ limitations under the License.

 import json
 import logging
-import threading
-import time
-from queue import Queue
-
-import torch
+from typing import Optional

 from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
 from sglang.srt.managers.io_struct import UpdateWeightReqInput
-from sglang.srt.managers.schedule_batch import ModelWorkerBatch
+from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.server_args import ServerArgs
@@ -40,9 +36,10 @@ class TpModelWorker:

     def __init__(
         self,
+        server_args: ServerArgs,
         gpu_id: int,
         tp_rank: int,
-        server_args: ServerArgs,
+        dp_rank: Optional[int],
         nccl_port: int,
     ):
         # Parse args
@@ -93,10 +90,14 @@ class TpModelWorker:
             ),
             self.model_runner.req_to_token_pool.size,
         )
-        self.max_req_input_len = min(
+        self.max_req_len = min(
             self.model_config.context_len - 1,
             self.max_total_num_tokens - 1,
         )
+        self.max_req_input_len = self.max_req_len - 5
+        assert (
+            self.max_req_len > 0 and self.max_req_input_len > 0
+        ), "Memory pool size is too small"

         # Sync random seed across TP workers
         self.random_seed = broadcast_pyobj(
@@ -106,92 +107,32 @@ class TpModelWorker:
         )[0]
         set_random_seed(self.random_seed)

-
-        self.init_overlap_status()
-
-    def get_token_and_memory_info(self):
+    def get_worker_info(self):
         return (
             self.max_total_num_tokens,
             self.max_prefill_tokens,
             self.max_running_requests,
+            self.max_req_len,
             self.max_req_input_len,
             self.random_seed,
+            self.device,
+            global_server_args_dict,
+            self.model_runner.req_to_token_pool.size,
+            self.model_runner.req_to_token_pool.max_context_len,
+            self.model_runner.token_to_kv_pool.size,
         )

-    def init_overlap_status(self):
-        self.
-
-
-        self.
-
-
-
-
-
-        self.future_event_map = dict()
-        self.forward_queue = Queue()
-        self.forward_stream = torch.cuda.Stream()
-        self.forward_thread = threading.Thread(
-            target=self.forward_thread_func,
+    def get_pad_input_ids_func(self):
+        return getattr(self.model_runner.model, "pad_input_ids", None)
+
+    def get_tp_cpu_group(self):
+        return self.model_runner.tp_group.cpu_group
+
+    def get_memory_pool(self):
+        return (
+            self.model_runner.req_to_token_pool,
+            self.model_runner.token_to_kv_pool,
         )
-        self.forward_thread.start()
-
-    def forward_thread_func(self):
-        with torch.cuda.stream(self.forward_stream):
-            self.forward_thread_func_()
-
-    @torch.inference_mode()
-    def forward_thread_func_(self):
-        while True:
-            tic1 = time.time()
-            model_worker_batch, future_logits_output, future_next_token_ids = (
-                self.forward_queue.get()
-            )
-
-            # Resolve future tokens in the input
-            tic2 = time.time()
-            resolved_input_ids = model_worker_batch.input_ids
-            future_mask = resolved_input_ids < 0
-            resolved_input_ids[future_mask] = self.future_token_ids_map[
-                -resolved_input_ids[future_mask]
-            ]
-
-            # Run forward
-            logits_output, next_token_ids = self.forward_batch_generation(
-                model_worker_batch
-            )
-
-            # Set future values
-            if model_worker_batch.return_logprob:
-                self.future_logits_output_dict[future_logits_output] = logits_output
-
-            # logger.info(f"set output {future_next_token_ids=}, {next_token_ids=}")
-            self.future_token_ids_map[-future_next_token_ids] = next_token_ids.to(
-                torch.int32
-            )
-            # logger.info("Set event")
-            self.future_token_ids_output[model_worker_batch.bid] = (
-                next_token_ids.tolist()
-            )
-            self.future_event_map[model_worker_batch.bid].set()
-
-            if False:
-                tic3 = time.time()
-                self.acc_time_with_waiting += tic3 - tic1
-                self.acc_time_without_waiting += tic3 - tic2
-                if self.forward_queue.qsize() == 0:
-                    logger.info(
-                        f"{self.acc_time_with_waiting=:.3f}, {self.acc_time_without_waiting=:.3f}, {self.forward_queue.qsize()=}"
-                    )
-
-    def resolve_future_token_ids(self, bid: int):
-        self.future_event_map[bid].wait()
-        ret = self.future_token_ids_output[bid]
-        del self.future_event_map[bid]
-        return ret
-
-    def resolve_future_logits_output(self, future_obj):
-        return self.future_logits_output_dict.pop(future_obj)

     def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch):
         forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
@@ -205,32 +146,6 @@ class TpModelWorker:
         embeddings = logits_output.embeddings
         return embeddings

-    def forward_batch_generation_non_blocking(
-        self, model_worker_batch: ModelWorkerBatch
-    ):
-        # Allocate output future objects
-        future_logits_output = self.future_logits_output_ct
-        self.future_logits_output_ct += 1
-
-        bs = len(model_worker_batch.seq_lens)
-        with torch.cuda.stream(self.forward_stream):
-            future_next_token_ids = -torch.arange(
-                self.future_token_ids_ct + 1,
-                self.future_token_ids_ct + 1 + bs,
-                dtype=torch.int32,
-                device=self.device,
-            )
-            self.future_token_ids_ct = (
-                self.future_token_ids_ct + bs
-            ) % self.future_token_ids_limit
-        ret = future_logits_output, future_next_token_ids
-
-        self.future_event_map[model_worker_batch.bid] = threading.Event()
-        self.forward_queue.put(
-            (model_worker_batch.copy(), future_logits_output, future_next_token_ids)
-        )
-        return ret
-
     def update_weights(self, recv_req: UpdateWeightReqInput):
         success, message = self.model_runner.update_weights(
             recv_req.model_path, recv_req.load_format