sglang 0.3.4__py3-none-any.whl → 0.3.4.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53)
  1. sglang/bench_latency.py +2 -1
  2. sglang/lang/chat_template.py +17 -0
  3. sglang/launch_server_llavavid.py +1 -1
  4. sglang/srt/configs/__init__.py +3 -0
  5. sglang/srt/configs/model_config.py +27 -2
  6. sglang/srt/configs/qwen2vl.py +133 -0
  7. sglang/srt/constrained/fsm_cache.py +10 -3
  8. sglang/srt/conversation.py +27 -0
  9. sglang/srt/hf_transformers_utils.py +16 -1
  10. sglang/srt/layers/attention/__init__.py +16 -5
  11. sglang/srt/layers/attention/double_sparsity_backend.py +22 -6
  12. sglang/srt/layers/attention/flashinfer_backend.py +174 -54
  13. sglang/srt/layers/attention/triton_backend.py +22 -6
  14. sglang/srt/layers/attention/triton_ops/prefill_attention.py +26 -4
  15. sglang/srt/layers/linear.py +89 -63
  16. sglang/srt/layers/logits_processor.py +5 -5
  17. sglang/srt/layers/rotary_embedding.py +112 -0
  18. sglang/srt/layers/sampler.py +51 -39
  19. sglang/srt/lora/lora.py +3 -1
  20. sglang/srt/managers/data_parallel_controller.py +1 -1
  21. sglang/srt/managers/detokenizer_manager.py +4 -0
  22. sglang/srt/managers/image_processor.py +186 -13
  23. sglang/srt/managers/io_struct.py +10 -0
  24. sglang/srt/managers/schedule_batch.py +238 -68
  25. sglang/srt/managers/scheduler.py +69 -50
  26. sglang/srt/managers/tokenizer_manager.py +24 -4
  27. sglang/srt/managers/tp_worker.py +26 -111
  28. sglang/srt/managers/tp_worker_overlap_thread.py +209 -0
  29. sglang/srt/mem_cache/memory_pool.py +56 -10
  30. sglang/srt/mem_cache/radix_cache.py +4 -3
  31. sglang/srt/model_executor/cuda_graph_runner.py +87 -28
  32. sglang/srt/model_executor/forward_batch_info.py +83 -3
  33. sglang/srt/model_executor/model_runner.py +32 -11
  34. sglang/srt/models/chatglm.py +3 -3
  35. sglang/srt/models/deepseek_v2.py +2 -2
  36. sglang/srt/models/mllama.py +1004 -0
  37. sglang/srt/models/qwen2_vl.py +724 -0
  38. sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +6 -3
  39. sglang/srt/sampling/sampling_batch_info.py +13 -3
  40. sglang/srt/sampling/sampling_params.py +5 -7
  41. sglang/srt/server.py +12 -0
  42. sglang/srt/server_args.py +10 -0
  43. sglang/srt/utils.py +22 -0
  44. sglang/test/run_eval.py +2 -0
  45. sglang/test/runners.py +20 -1
  46. sglang/test/srt/sampling/penaltylib/utils.py +1 -0
  47. sglang/test/test_utils.py +100 -3
  48. sglang/version.py +1 -1
  49. {sglang-0.3.4.dist-info → sglang-0.3.4.post2.dist-info}/METADATA +17 -18
  50. {sglang-0.3.4.dist-info → sglang-0.3.4.post2.dist-info}/RECORD +53 -48
  51. {sglang-0.3.4.dist-info → sglang-0.3.4.post2.dist-info}/LICENSE +0 -0
  52. {sglang-0.3.4.dist-info → sglang-0.3.4.post2.dist-info}/WHEEL +0 -0
  53. {sglang-0.3.4.dist-info → sglang-0.3.4.post2.dist-info}/top_level.txt +0 -0
sglang/srt/managers/scheduler.py
@@ -38,6 +38,8 @@ from sglang.srt.managers.io_struct import (
  BatchEmbeddingOut,
  BatchTokenIDOut,
  FlushCacheReq,
+ GetMemPoolSizeReq,
+ GetMemPoolSizeReqOutput,
  ProfileReq,
  TokenizedEmbeddingReqInput,
  TokenizedGenerateReqInput,
@@ -51,6 +53,7 @@ from sglang.srt.managers.schedule_batch import (
  ImageInputs,
  Req,
  ScheduleBatch,
+ global_server_args_dict,
  )
  from sglang.srt.managers.schedule_policy import (
  AddReqResult,
@@ -58,6 +61,7 @@ from sglang.srt.managers.schedule_policy import (
  SchedulePolicy,
  )
  from sglang.srt.managers.tp_worker import TpModelWorker
+ from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
  from sglang.srt.mem_cache.chunk_cache import ChunkCache
  from sglang.srt.mem_cache.radix_cache import RadixCache
  from sglang.srt.server_args import PortArgs, ServerArgs
@@ -67,7 +71,6 @@ from sglang.srt.utils import (
  is_generation_model,
  is_multimodal_model,
  kill_parent_process,
- pytorch_profile,
  set_random_seed,
  suppress_other_loggers,
  )
@@ -91,6 +94,7 @@ class Scheduler:
  port_args: PortArgs,
  gpu_id: int,
  tp_rank: int,
+ dp_rank: Optional[int],
  ):
  # Parse args
  self.server_args = server_args
@@ -100,6 +104,7 @@ class Scheduler:
  self.disable_regex_jump_forward = server_args.disable_regex_jump_forward
  self.lora_paths = server_args.lora_paths
  self.max_loras_per_batch = server_args.max_loras_per_batch
+ self.enable_overlap = server_args.enable_overlap_schedule

  # Init inter-process communication
  context = zmq.Context(2)
@@ -143,27 +148,37 @@ class Scheduler:
  )

  # Launch a tensor parallel worker
- self.tp_worker = TpModelWorker(
+ if self.enable_overlap:
+ TpWorkerClass = TpModelWorkerClient
+ else:
+ TpWorkerClass = TpModelWorker
+
+ self.tp_worker = TpWorkerClass(
+ server_args=server_args,
  gpu_id=gpu_id,
  tp_rank=tp_rank,
- server_args=server_args,
+ dp_rank=dp_rank,
  nccl_port=port_args.nccl_port,
  )
- self.tp_cpu_group = self.tp_worker.model_runner.tp_group.cpu_group
- self.device = self.tp_worker.device

  # Get token and memory info from the model worker
  (
  self.max_total_num_tokens,
  self.max_prefill_tokens,
  self.max_running_requests,
+ self.max_req_len,
  self.max_req_input_len,
  self.random_seed,
- ) = self.tp_worker.get_token_and_memory_info()
+ self.device,
+ worker_global_server_args_dict,
+ _,
+ _,
+ _,
+ ) = self.tp_worker.get_worker_info()
+ self.tp_cpu_group = self.tp_worker.get_tp_cpu_group()
+ self.pad_input_ids_func = self.tp_worker.get_pad_input_ids_func()
+ global_server_args_dict.update(worker_global_server_args_dict)
  set_random_seed(self.random_seed)
- self.pad_input_ids_func = getattr(
- self.tp_worker.model_runner.model, "pad_input_ids", None
- )

  # Print debug info
  logger.info(
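Note: the hunk above swaps direct access to tp_worker.model_runner for a single get_worker_info() tuple and picks the worker class from the overlap flag. A minimal standalone Python sketch of that select-and-unpack pattern; the classes and numbers below are hypothetical stand-ins, not the sglang implementation:

    # Hypothetical stand-ins for TpModelWorker / TpModelWorkerClient.
    class BlockingWorker:
        def get_worker_info(self):
            # (max_total_num_tokens, max_prefill_tokens, max_running_requests,
            #  max_req_len, max_req_input_len, random_seed, device, server_args_dict,
            #  req_pool_size, req_pool_max_context_len, kv_pool_size)
            return (65536, 16384, 2048, 8191, 8186, 42, "cuda", {}, 2048, 8192, 65536)

    class OverlapWorker(BlockingWorker):
        pass

    def launch_worker(enable_overlap: bool):
        # Choose the worker class once, then construct it with the same arguments.
        worker_class = OverlapWorker if enable_overlap else BlockingWorker
        return worker_class()

    worker = launch_worker(enable_overlap=True)
    (
        max_total_num_tokens,
        max_prefill_tokens,
        max_running_requests,
        max_req_len,
        max_req_input_len,
        random_seed,
        device,
        server_args_dict,
        _,
        _,
        _,
    ) = worker.get_worker_info()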
@@ -173,9 +188,8 @@ class Scheduler:
  f"context_len={self.model_config.context_len}"
  )

- # Init cache
- self.req_to_token_pool = self.tp_worker.model_runner.req_to_token_pool
- self.token_to_kv_pool = self.tp_worker.model_runner.token_to_kv_pool
+ # Init memory pool and cache
+ self.req_to_token_pool, self.token_to_kv_pool = self.tp_worker.get_memory_pool()

  if (
  server_args.chunked_prefill_size is not None
@@ -253,22 +267,9 @@ class Scheduler:
  with_stack=True,
  )

- # Init states for overlap schedule
- if self.server_args.enable_overlap_schedule:
- self.forward_batch_generation = (
- self.tp_worker.forward_batch_generation_non_blocking
- )
- self.resolve_next_token_ids = (
- lambda bid, x: self.tp_worker.resolve_future_token_ids(bid)
- )
- self.cache_finished_req = self.tree_cache.cache_finished_req
- else:
- self.forward_batch_generation = self.tp_worker.forward_batch_generation
- self.resolve_next_token_ids = lambda bid, x: x.tolist()
- self.cache_finished_req = self.tree_cache.cache_finished_req
-
  @torch.inference_mode()
  def event_loop_normal(self):
+ """A normal blocking scheduler loop."""
  self.last_batch = None

  while True:
@@ -299,6 +300,7 @@ class Scheduler:

  @torch.inference_mode()
  def event_loop_overlap(self):
+ """A scheduler loop that overlaps the CPU processing and GPU computation."""
  result_queue = deque()

  self.last_batch = None
@@ -362,6 +364,10 @@ class Scheduler:
  self.start_profile()
  else:
  self.stop_profile()
+ elif isinstance(recv_req, GetMemPoolSizeReq):
+ self.send_to_detokenizer.send_pyobj(
+ GetMemPoolSizeReqOutput(self.max_total_num_tokens)
+ )
  else:
  raise ValueError(f"Invalid request: {recv_req}")

@@ -415,19 +421,20 @@ class Scheduler:
  )

  # Truncate prompts that are too long
- if len(req.origin_input_ids) >= self.max_req_input_len:
+ if len(req.origin_input_ids) > self.max_req_input_len:
  logger.warning(
  "Request length is longer than the KV cache pool size or "
  "the max context length. Truncated!!!"
  )
  req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len]
+
  req.sampling_params.max_new_tokens = min(
  (
  req.sampling_params.max_new_tokens
  if req.sampling_params.max_new_tokens is not None
  else 1 << 30
  ),
- self.max_req_input_len - 1 - len(req.origin_input_ids),
+ self.max_req_len - len(req.origin_input_ids) - 1,
  )

  self.waiting_queue.append(req)
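For clarity, the truncation and max_new_tokens clamping in the hunk above, spelled out as a tiny standalone sketch (all numbers are illustrative, not sglang defaults):

    max_req_len = 8191        # illustrative: prompt + generated tokens must stay under this
    max_req_input_len = 8186  # illustrative: longest prompt that is kept

    origin_input_ids = list(range(10000))  # an over-long prompt
    requested_max_new_tokens = None        # client did not set a limit

    if len(origin_input_ids) > max_req_input_len:
        origin_input_ids = origin_input_ids[:max_req_input_len]  # truncate the prompt

    max_new_tokens = min(
        requested_max_new_tokens if requested_max_new_tokens is not None else 1 << 30,
        max_req_len - len(origin_input_ids) - 1,
    )
    print(max_new_tokens)  # 8191 - 8186 - 1 = 4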
@@ -575,6 +582,7 @@ class Scheduler:
  else set([])
  )

+ # Get requests from the waiting queue to a new prefill batch
  for req in self.waiting_queue:
  if (
  self.lora_paths
@@ -661,12 +669,13 @@ class Scheduler:
  self.req_to_token_pool,
  self.token_to_kv_pool,
  self.tree_cache,
+ self.model_config,
  )
- new_batch.prepare_for_extend(self.model_config.vocab_size)
+ new_batch.prepare_for_extend()

  # Mixed-style chunked prefill
  if self.is_mixed_chunk and self.running_batch is not None:
- self.running_batch.prepare_for_decode()
+ self.running_batch.prepare_for_decode(self.enable_overlap)
  new_batch.mix_with_running(self.running_batch)
  new_batch.decoding_reqs = self.running_batch.reqs
  self.running_batch = None
@@ -676,6 +685,7 @@ class Scheduler:
  return new_batch

  def update_running_batch(self):
+ """Update the current running decoding batch."""
  global test_retract
  batch = self.running_batch

@@ -712,13 +722,14 @@ class Scheduler:
  return

  # Update batch tensors
- batch.prepare_for_decode()
+ batch.prepare_for_decode(self.enable_overlap)

  def run_batch(self, batch: ScheduleBatch):
+ """Run a batch."""
  if self.is_generation:
  if batch.forward_mode.is_decode() or batch.extend_num_tokens != 0:
  model_worker_batch = batch.get_model_worker_batch()
- logits_output, next_token_ids = self.forward_batch_generation(
+ logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
  model_worker_batch
  )
  else:
@@ -749,9 +760,12 @@ class Scheduler:
  def process_batch_result_prefill(self, batch: ScheduleBatch, result):
  if self.is_generation:
  logits_output, next_token_ids, bid = result
- if batch.return_logprob:
- # Move logprobs to cpu
- if logits_output.next_token_logprobs is not None:
+
+ if self.enable_overlap:
+ logits_output, next_token_ids = self.tp_worker.resulve_batch_result(bid)
+ else:
+ # Move next_token_ids and logprobs to cpu
+ if batch.return_logprob:
  logits_output.next_token_logprobs = (
  logits_output.next_token_logprobs[
  torch.arange(len(next_token_ids), device=self.device),
@@ -764,8 +778,7 @@ class Scheduler:
  logits_output.normalized_prompt_logprobs = (
  logits_output.normalized_prompt_logprobs.tolist()
  )
-
- next_token_ids = self.resolve_next_token_ids(bid, next_token_ids)
+ next_token_ids = next_token_ids.tolist()

  # Check finish conditions
  logprob_pt = 0
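The non-overlap branch above gathers, for each request in the batch, the logprob of the token that was actually sampled before moving it to the CPU. A toy reproduction of that indexing pattern with made-up tensors (illustrative only):

    import torch

    batch_size, vocab_size = 3, 5
    # one row of log-probabilities per request in the batch
    all_logprobs = torch.log_softmax(torch.randn(batch_size, vocab_size), dim=-1)
    next_token_ids = torch.tensor([4, 0, 2])  # token sampled for each request

    # advanced indexing picks all_logprobs[i, next_token_ids[i]] for every i
    next_token_logprobs = all_logprobs[
        torch.arange(batch_size), next_token_ids
    ].tolist()
    next_token_ids = next_token_ids.tolist()  # plain Python ints for the scheduler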
@@ -779,7 +792,7 @@ class Scheduler:
  req.check_finished()

  if req.finished():
- self.cache_finished_req(req)
+ self.tree_cache.cache_finished_req(req)
  elif not batch.decoding_reqs or req not in batch.decoding_reqs:
  self.tree_cache.cache_unfinished_req(req)

@@ -808,7 +821,7 @@ class Scheduler:
  req.check_finished()

  if req.finished():
- self.cache_finished_req(req)
+ self.tree_cache.cache_finished_req(req)
  else:
  self.tree_cache.cache_unfinished_req(req)

@@ -818,14 +831,17 @@ class Scheduler:
  logits_output, next_token_ids, bid = result
  self.num_generated_tokens += len(batch.reqs)

- # Move logprobs to cpu
- if batch.return_logprob:
- next_token_logprobs = logits_output.next_token_logprobs[
- torch.arange(len(next_token_ids), device=self.device),
- next_token_ids,
- ].tolist()
-
- next_token_ids = self.resolve_next_token_ids(bid, next_token_ids)
+ if self.enable_overlap:
+ logits_output, next_token_ids = self.tp_worker.resulve_batch_result(bid)
+ next_token_logprobs = logits_output.next_token_logprobs
+ else:
+ # Move next_token_ids and logprobs to cpu
+ if batch.return_logprob:
+ next_token_logprobs = logits_output.next_token_logprobs[
+ torch.arange(len(next_token_ids), device=self.device),
+ next_token_ids,
+ ].tolist()
+ next_token_ids = next_token_ids.tolist()

  self.token_to_kv_pool.free_group_begin()

@@ -845,7 +861,7 @@ class Scheduler:
  )

  if req.finished():
- self.cache_finished_req(req)
+ self.tree_cache.cache_finished_req(req)

  if req.return_logprob:
  req.output_token_logprobs.append(
@@ -936,6 +952,7 @@ class Scheduler:
  return num_input_logprobs

  def stream_output(self, reqs: List[Req]):
+ """Stream the output to detokenizer."""
  output_rids = []
  output_meta_info = []
  output_finished_reason: List[BaseFinishReason] = []
@@ -1033,6 +1050,7 @@ class Scheduler:
  )

  def flush_cache(self):
+ """Flush the memory pool and cache."""
  if len(self.waiting_queue) == 0 and (
  self.running_batch is None or len(self.running_batch.reqs) == 0
  ):
@@ -1069,10 +1087,11 @@ class Scheduler:
  for req in self.running_batch.reqs:
  if req.rid == recv_req.rid and not req.finished():
  req.finished_reason = FINISH_ABORT()
- self.cache_finished_req(req)
+ self.tree_cache.cache_finished_req(req)
  break

  def update_weights(self, recv_req: UpdateWeightReqInput):
+ """In-place update of the weights."""
  success, message = self.tp_worker.update_weights(recv_req)
  if success:
  flash_cache_success = self.flush_cache()
@@ -1112,7 +1131,7 @@ def run_scheduler_process(
  suppress_other_loggers()

  try:
- scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank)
+ scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank)
  pipe_writer.send("ready")
  if server_args.enable_overlap_schedule:
  scheduler.event_loop_overlap()
sglang/srt/managers/tokenizer_manager.py
@@ -46,6 +46,8 @@ from sglang.srt.managers.io_struct import (
  EmbeddingReqInput,
  FlushCacheReq,
  GenerateReqInput,
+ GetMemPoolSizeReq,
+ GetMemPoolSizeReqOutput,
  ProfileReq,
  RewardReqInput,
  TokenizedEmbeddingReqInput,
@@ -122,7 +124,7 @@ class TokenizerManager:

  # We want to parallelize the image pre-processing so we create an executor for it
  self.image_processor = get_image_processor(
- self.hf_config, server_args, self.processor.image_processor
+ self.hf_config, server_args, self.processor
  )
  else:
  self.tokenizer = get_tokenizer(
@@ -191,8 +193,10 @@ class TokenizerManager:
  sampling_params = self._get_sampling_params(obj.sampling_params)
  if self.is_generation:
  image_inputs = await self.image_processor.process_images_async(
- obj.image_data, obj
+ obj.image_data, input_text or input_ids, obj
  )
+ if image_inputs and "input_ids" in image_inputs:
+ input_ids = image_inputs["input_ids"]
  return_logprob = obj.return_logprob
  logprob_start_len = obj.logprob_start_len
  top_logprobs_num = obj.top_logprobs_num
@@ -217,8 +221,10 @@ class TokenizerManager:
  sampling_params = self._get_sampling_params(obj.sampling_params[index])
  if self.is_generation:
  image_inputs = await self.image_processor.process_images_async(
- obj.image_data[index], obj
+ obj.image_data[index], input_text or input_ids, obj
  )
+ if image_inputs and "input_ids" in image_inputs:
+ input_ids = image_inputs["input_ids"]
  return_logprob = obj.return_logprob[index]
  logprob_start_len = obj.logprob_start_len[index]
  top_logprobs_num = obj.top_logprobs_num[index]
@@ -263,8 +269,10 @@ class TokenizerManager:
  sampling_params = SamplingParams(**obj.sampling_params[0])
  sampling_params.max_new_tokens = 0
  image_inputs = await self.image_processor.process_images_async(
- obj.image_data[0], obj
+ obj.image_data[0], input_text or input_ids, obj
  )
+ if image_inputs and "input_ids" in image_inputs:
+ input_ids = image_inputs["input_ids"]
  return_logprob = obj.return_logprob[0]
  logprob_start_len = obj.logprob_start_len[0]
  top_logprobs_num = obj.top_logprobs_num[0]
@@ -525,6 +533,15 @@ class TokenizerManager:
  req = ProfileReq.STOP_PROFILE
  self.send_to_scheduler.send_pyobj(req)

+ async def get_memory_pool_size(self):
+ if self.to_create_loop:
+ self.create_handle_loop()
+
+ req = GetMemPoolSizeReq()
+ self.send_to_scheduler.send_pyobj(req)
+ self.mem_pool_size = asyncio.Future()
+ return await self.mem_pool_size
+
  async def update_weights(
  self, obj: UpdateWeightReqInput, request: Optional[fastapi.Request] = None
  ):
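The new get_memory_pool_size above follows the manager's usual request/response pattern: send a request object to the scheduler, park an asyncio.Future, and let the receive loop resolve it when GetMemPoolSizeReqOutput comes back. A minimal self-contained sketch of that pattern, with an asyncio.Queue standing in for the ZMQ sockets (illustrative only, not the sglang implementation):

    import asyncio

    class MiniManager:
        def __init__(self):
            self.channel = asyncio.Queue()  # stands in for the ZMQ send/recv sockets
            self.mem_pool_size = None

        async def get_memory_pool_size(self):
            self.mem_pool_size = asyncio.get_running_loop().create_future()
            await self.channel.put("GetMemPoolSizeReq")  # stands in for send_pyobj(req)
            return await self.mem_pool_size              # resolved by handle_loop()

        async def handle_loop(self):
            # In sglang this loop receives GetMemPoolSizeReqOutput from the scheduler;
            # here we just consume the request and resolve the Future directly.
            await self.channel.get()
            self.mem_pool_size.set_result(65536)  # e.g. max_total_num_tokens

    async def main():
        mgr = MiniManager()
        asyncio.create_task(mgr.handle_loop())
        print(await mgr.get_memory_pool_size())  # 65536

    asyncio.run(main())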
@@ -584,6 +601,9 @@ class TokenizerManager:
  if isinstance(recv_obj, UpdateWeightReqOutput):
  self.model_update_result.set_result(recv_obj)
  continue
+ elif isinstance(recv_obj, GetMemPoolSizeReqOutput):
+ self.mem_pool_size.set_result(recv_obj)
+ continue

  assert isinstance(
  recv_obj, (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut)
sglang/srt/managers/tp_worker.py
@@ -17,16 +17,12 @@ limitations under the License.

  import json
  import logging
- import threading
- import time
- from queue import Queue
-
- import torch
+ from typing import Optional

  from sglang.srt.configs.model_config import ModelConfig
  from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
  from sglang.srt.managers.io_struct import UpdateWeightReqInput
- from sglang.srt.managers.schedule_batch import ModelWorkerBatch
+ from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_args_dict
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
  from sglang.srt.model_executor.model_runner import ModelRunner
  from sglang.srt.server_args import ServerArgs
@@ -40,9 +36,10 @@ class TpModelWorker:

  def __init__(
  self,
+ server_args: ServerArgs,
  gpu_id: int,
  tp_rank: int,
- server_args: ServerArgs,
+ dp_rank: Optional[int],
  nccl_port: int,
  ):
  # Parse args
@@ -93,10 +90,14 @@ class TpModelWorker:
  ),
  self.model_runner.req_to_token_pool.size,
  )
- self.max_req_input_len = min(
+ self.max_req_len = min(
  self.model_config.context_len - 1,
  self.max_total_num_tokens - 1,
  )
+ self.max_req_input_len = self.max_req_len - 5
+ assert (
+ self.max_req_len > 0 and self.max_req_input_len > 0
+ ), "Memory pool size is too small"

  # Sync random seed across TP workers
  self.random_seed = broadcast_pyobj(
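To make the limits above concrete: max_req_len is capped by both the model's context window and the KV cache pool, and max_req_input_len sits a 5-token margin below it. A tiny sketch with made-up numbers:

    context_len = 8192            # model context window (illustrative)
    max_total_num_tokens = 65536  # KV cache pool size in tokens (illustrative)

    max_req_len = min(context_len - 1, max_total_num_tokens - 1)  # 8191
    max_req_input_len = max_req_len - 5                           # 8186
    assert max_req_len > 0 and max_req_input_len > 0, "Memory pool size is too small"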
@@ -106,92 +107,32 @@ class TpModelWorker:
  )[0]
  set_random_seed(self.random_seed)

- if server_args.enable_overlap_schedule:
- self.init_overlap_status()
-
- def get_token_and_memory_info(self):
+ def get_worker_info(self):
  return (
  self.max_total_num_tokens,
  self.max_prefill_tokens,
  self.max_running_requests,
+ self.max_req_len,
  self.max_req_input_len,
  self.random_seed,
+ self.device,
+ global_server_args_dict,
+ self.model_runner.req_to_token_pool.size,
+ self.model_runner.req_to_token_pool.max_context_len,
+ self.model_runner.token_to_kv_pool.size,
  )

- def init_overlap_status(self):
- self.future_logits_output_dict = dict()
- self.future_logits_output_ct = 0
- self.future_token_ids_ct = 0
- self.future_token_ids_map = torch.empty(
- (self.max_running_requests * 5,), dtype=torch.int32, device=self.device
- )
- self.future_token_ids_limit = self.max_running_requests * 3
- self.future_token_ids_output = dict()
-
- self.future_event_map = dict()
- self.forward_queue = Queue()
- self.forward_stream = torch.cuda.Stream()
- self.forward_thread = threading.Thread(
- target=self.forward_thread_func,
+ def get_pad_input_ids_func(self):
+ return getattr(self.model_runner.model, "pad_input_ids", None)
+
+ def get_tp_cpu_group(self):
+ return self.model_runner.tp_group.cpu_group
+
+ def get_memory_pool(self):
+ return (
+ self.model_runner.req_to_token_pool,
+ self.model_runner.token_to_kv_pool,
  )
- self.forward_thread.start()
-
- def forward_thread_func(self):
- with torch.cuda.stream(self.forward_stream):
- self.forward_thread_func_()
-
- @torch.inference_mode()
- def forward_thread_func_(self):
- while True:
- tic1 = time.time()
- model_worker_batch, future_logits_output, future_next_token_ids = (
- self.forward_queue.get()
- )
-
- # Resolve future tokens in the input
- tic2 = time.time()
- resolved_input_ids = model_worker_batch.input_ids
- future_mask = resolved_input_ids < 0
- resolved_input_ids[future_mask] = self.future_token_ids_map[
- -resolved_input_ids[future_mask]
- ]
-
- # Run forward
- logits_output, next_token_ids = self.forward_batch_generation(
- model_worker_batch
- )
-
- # Set future values
- if model_worker_batch.return_logprob:
- self.future_logits_output_dict[future_logits_output] = logits_output
-
- # logger.info(f"set output {future_next_token_ids=}, {next_token_ids=}")
- self.future_token_ids_map[-future_next_token_ids] = next_token_ids.to(
- torch.int32
- )
- # logger.info("Set event")
- self.future_token_ids_output[model_worker_batch.bid] = (
- next_token_ids.tolist()
- )
- self.future_event_map[model_worker_batch.bid].set()
-
- if False:
- tic3 = time.time()
- self.acc_time_with_waiting += tic3 - tic1
- self.acc_time_without_waiting += tic3 - tic2
- if self.forward_queue.qsize() == 0:
- logger.info(
- f"{self.acc_time_with_waiting=:.3f}, {self.acc_time_without_waiting=:.3f}, {self.forward_queue.qsize()=}"
- )
-
- def resolve_future_token_ids(self, bid: int):
- self.future_event_map[bid].wait()
- ret = self.future_token_ids_output[bid]
- del self.future_event_map[bid]
- return ret
-
- def resolve_future_logits_output(self, future_obj):
- return self.future_logits_output_dict.pop(future_obj)

  def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch):
  forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
@@ -205,32 +146,6 @@ class TpModelWorker:
  embeddings = logits_output.embeddings
  return embeddings

- def forward_batch_generation_non_blocking(
- self, model_worker_batch: ModelWorkerBatch
- ):
- # Allocate output future objects
- future_logits_output = self.future_logits_output_ct
- self.future_logits_output_ct += 1
-
- bs = len(model_worker_batch.seq_lens)
- with torch.cuda.stream(self.forward_stream):
- future_next_token_ids = -torch.arange(
- self.future_token_ids_ct + 1,
- self.future_token_ids_ct + 1 + bs,
- dtype=torch.int32,
- device=self.device,
- )
- self.future_token_ids_ct = (
- self.future_token_ids_ct + bs
- ) % self.future_token_ids_limit
- ret = future_logits_output, future_next_token_ids
-
- self.future_event_map[model_worker_batch.bid] = threading.Event()
- self.forward_queue.put(
- (model_worker_batch.copy(), future_logits_output, future_next_token_ids)
- )
- return ret
-

  def update_weights(self, recv_req: UpdateWeightReqInput):
  success, message = self.model_runner.update_weights(