sglang 0.3.5.post2__py3-none-any.whl → 0.3.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (60)
  1. sglang/bench_latency.py +1 -553
  2. sglang/bench_offline_throughput.py +48 -20
  3. sglang/bench_one_batch.py +474 -0
  4. sglang/{bench_server_latency.py → bench_one_batch_server.py} +3 -3
  5. sglang/bench_serving.py +71 -1
  6. sglang/check_env.py +3 -6
  7. sglang/srt/constrained/outlines_backend.py +15 -2
  8. sglang/srt/constrained/xgrammar_backend.py +22 -14
  9. sglang/srt/layers/activation.py +3 -0
  10. sglang/srt/layers/attention/flashinfer_backend.py +93 -48
  11. sglang/srt/layers/attention/triton_backend.py +9 -7
  12. sglang/srt/layers/custom_op_util.py +26 -0
  13. sglang/srt/layers/fused_moe/fused_moe.py +11 -4
  14. sglang/srt/layers/layernorm.py +4 -0
  15. sglang/srt/layers/logits_processor.py +10 -10
  16. sglang/srt/layers/sampler.py +4 -8
  17. sglang/srt/layers/torchao_utils.py +2 -0
  18. sglang/srt/managers/data_parallel_controller.py +74 -9
  19. sglang/srt/managers/detokenizer_manager.py +1 -0
  20. sglang/srt/managers/io_struct.py +27 -0
  21. sglang/srt/managers/schedule_batch.py +104 -38
  22. sglang/srt/managers/schedule_policy.py +5 -1
  23. sglang/srt/managers/scheduler.py +204 -54
  24. sglang/srt/managers/session_controller.py +62 -0
  25. sglang/srt/managers/tokenizer_manager.py +38 -0
  26. sglang/srt/managers/tp_worker.py +12 -1
  27. sglang/srt/managers/tp_worker_overlap_thread.py +49 -52
  28. sglang/srt/model_executor/cuda_graph_runner.py +43 -6
  29. sglang/srt/model_executor/forward_batch_info.py +109 -15
  30. sglang/srt/model_executor/model_runner.py +99 -43
  31. sglang/srt/model_parallel.py +98 -0
  32. sglang/srt/models/deepseek_v2.py +147 -44
  33. sglang/srt/models/gemma2.py +9 -8
  34. sglang/srt/models/llava.py +1 -1
  35. sglang/srt/models/llavavid.py +1 -1
  36. sglang/srt/models/olmo.py +3 -3
  37. sglang/srt/models/phi3_small.py +447 -0
  38. sglang/srt/models/qwen2_vl.py +13 -6
  39. sglang/srt/models/torch_native_llama.py +94 -78
  40. sglang/srt/openai_api/adapter.py +6 -2
  41. sglang/srt/openai_api/protocol.py +1 -1
  42. sglang/srt/sampling/penaltylib/orchestrator.py +49 -79
  43. sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +3 -8
  44. sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +3 -9
  45. sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +3 -8
  46. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +3 -8
  47. sglang/srt/sampling/sampling_batch_info.py +58 -57
  48. sglang/srt/sampling/sampling_params.py +1 -1
  49. sglang/srt/server.py +27 -1
  50. sglang/srt/server_args.py +78 -62
  51. sglang/srt/utils.py +71 -52
  52. sglang/test/runners.py +25 -6
  53. sglang/test/srt/sampling/penaltylib/utils.py +23 -21
  54. sglang/test/test_utils.py +30 -19
  55. sglang/version.py +1 -1
  56. {sglang-0.3.5.post2.dist-info → sglang-0.3.6.dist-info}/METADATA +43 -43
  57. {sglang-0.3.5.post2.dist-info → sglang-0.3.6.dist-info}/RECORD +60 -55
  58. {sglang-0.3.5.post2.dist-info → sglang-0.3.6.dist-info}/WHEEL +1 -1
  59. {sglang-0.3.5.post2.dist-info → sglang-0.3.6.dist-info}/LICENSE +0 -0
  60. {sglang-0.3.5.post2.dist-info → sglang-0.3.6.dist-info}/top_level.txt +0 -0

sglang/srt/managers/data_parallel_controller.py

@@ -17,6 +17,7 @@ limitations under the License.
 
 import logging
 import multiprocessing as mp
+import threading
 from enum import Enum, auto
 
 import zmq
@@ -28,6 +29,7 @@ from sglang.srt.managers.io_struct import (
 from sglang.srt.managers.scheduler import run_scheduler_process
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import (
+    bind_port,
     configure_logger,
     get_zmq_socket,
     kill_parent_process,
@@ -80,20 +82,62 @@ class DataParallelController:
 
         # Start data parallel workers
         base_gpu_id = 0
-        self.workers = []
+        self.workers = [None] * server_args.dp_size
+
+        threads = []
+        sockets = []
         for dp_rank in range(server_args.dp_size):
             tmp_port_args = PortArgs.init_new(server_args)
+            tmp_port_args.tokenizer_ipc_name = port_args.tokenizer_ipc_name
             tmp_port_args.detokenizer_ipc_name = port_args.detokenizer_ipc_name
 
-            send_to = self.launch_tensor_parallel_group(
-                server_args,
-                tmp_port_args,
-                base_gpu_id,
-                dp_rank,
+            if server_args.enable_dp_attention:
+                # Data parallelism resues the tensor parallelism group,
+                # so all dp ranks should use the same nccl port.
+                tmp_port_args.nccl_port = port_args.nccl_port
+            else:
+                # This port is checked free in PortArgs.init_new.
+                # We hold it first so that the next dp worker gets a different port
+                sockets.append(bind_port(tmp_port_args.nccl_port))
+
+            # Create a thread for each worker
+            thread = threading.Thread(
+                target=self.launch_worker_func,
+                args=(server_args, tmp_port_args, base_gpu_id, dp_rank),
             )
+            threads.append(thread)
+            base_gpu_id += 1 if server_args.enable_dp_attention else server_args.tp_size
+
+        # Free all sockets before starting the threads to launch TP workers
+        for sock in sockets:
+            sock.close()
+
+        # Start all threads
+        for thread in threads:
+            thread.start()
+        for thread in threads:
+            thread.join()
+
+    def launch_worker_func(
+        self,
+        server_args: ServerArgs,
+        port_args: PortArgs,
+        base_gpu_id: int,
+        dp_rank: int,
+    ):
+        logger.info(f"Launch DP{dp_rank} starting at GPU #{base_gpu_id}.")
 
-        self.workers.append(send_to)
-        base_gpu_id += server_args.tp_size
+        launch_func_ = (
+            self.launch_tensor_parallel_process
+            if server_args.enable_dp_attention
+            else self.launch_tensor_parallel_group
+        )
+        self.workers[dp_rank] = launch_func_(
+            server_args,
+            port_args,
+            base_gpu_id,
+            dp_rank,
+        )
 
     def launch_tensor_parallel_group(
         self,
@@ -112,7 +156,7 @@ class DataParallelController:
         )
         for tp_rank in tp_rank_range:
             reader, writer = mp.Pipe(duplex=False)
-            gpu_id = base_gpu_id + tp_rank % tp_size_per_node
+            gpu_id = server_args.base_gpu_id + base_gpu_id + tp_rank % tp_size_per_node
             proc = mp.Process(
                 target=run_scheduler_process,
                 args=(server_args, port_args, gpu_id, tp_rank, dp_rank, writer),
@@ -131,6 +175,27 @@ class DataParallelController:
 
         return send_to
 
+    def launch_tensor_parallel_process(
+        self,
+        server_args: ServerArgs,
+        port_args: PortArgs,
+        base_gpu_id: int,
+        dp_rank: int,
+    ):
+        reader, writer = mp.Pipe(duplex=False)
+        gpu_id = base_gpu_id
+        tp_rank = dp_rank
+        proc = mp.Process(
+            target=run_scheduler_process,
+            args=(server_args, port_args, gpu_id, tp_rank, dp_rank, writer),
+        )
+        proc.start()
+        send_to = get_zmq_socket(
+            self.context, zmq.PUSH, port_args.scheduler_input_ipc_name
+        )
+        reader.recv()
+        return send_to
+
     def round_robin_scheduler(self, req):
         self.workers[self.round_robin_counter].send_pyobj(req)
         self.round_robin_counter = (self.round_robin_counter + 1) % len(self.workers)
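
Note on the hunks above: PortArgs.init_new only checks that a port is currently free, so the controller now holds each freshly chosen NCCL port with bind_port and releases it just before the launch threads start, preventing the next DP rank from picking the same port. A minimal, self-contained sketch of that reserve-then-release pattern follows; reserve_port and launch_one are hypothetical stand-ins for illustration, not sglang APIs.

import socket
import threading


def reserve_port(port: int) -> socket.socket:
    """Hold `port` by binding a TCP socket to it, so no other allocator can take it."""
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    sock.bind(("", port))
    return sock


def launch_workers(ports, launch_one):
    # Reserve all ports first so later allocations cannot collide with them.
    holds = [reserve_port(p) for p in ports]
    # Release the reservations just before the workers bind the ports for real.
    for sock in holds:
        sock.close()
    # Launch one worker per port in its own thread, then wait for all of them.
    threads = [threading.Thread(target=launch_one, args=(p,)) for p in ports]
    for t in threads:
        t.start()
    for t in threads:
        t.join()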

sglang/srt/managers/detokenizer_manager.py

@@ -175,6 +175,7 @@ class DetokenizerManager:
                     output_strs=output_strs,
                     meta_info=recv_obj.meta_info,
                     finished_reason=recv_obj.finished_reason,
+                    session_ids=recv_obj.session_ids,
                 )
             )
 

sglang/srt/managers/io_struct.py

@@ -56,6 +56,10 @@ class GenerateReqInput:
     # LoRA related
     lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
 
+    # Session id info for continual prompting
+    session_id: Optional[Union[List[str], str]] = None
+    session_rid: Optional[Union[List[str], str]] = None
+
     def normalize_batch_and_arguments(self):
         if (self.text is None and self.input_ids is None) or (
             self.text is not None and self.input_ids is not None
@@ -200,6 +204,10 @@ class TokenizedGenerateReqInput:
     # LoRA related
     lora_path: Optional[str] = None  # None means just use the base model
 
+    # Session id info for continual prompting
+    session_id: Optional[int] = None
+    session_rid: Optional[str] = None
+
 
 @dataclass
 class EmbeddingReqInput:
@@ -293,6 +301,8 @@ class BatchTokenIDOut:
     meta_info: List[Dict]
     finished_reason: List[BaseFinishReason]
     no_stop_trim: List[bool]
+    # The updated session unique id
+    session_ids: List[str]
 
 
 @dataclass
@@ -305,6 +315,8 @@ class BatchStrOut:
     meta_info: List[Dict]
     # The finish reason
     finished_reason: List[BaseFinishReason]
+    # The update session unique id
+    session_ids: List[str]
 
 
 @dataclass
@@ -357,3 +369,18 @@ class GetMemPoolSizeReq:
 @dataclass
 class GetMemPoolSizeReqOutput:
     size: int
+
+
+@dataclass
+class OpenSessionReqInput:
+    capacity_of_str_len: int
+
+
+@dataclass
+class CloseSessionReqInput:
+    session_id: str
+
+
+@dataclass
+class OpenSessionReqOutput:
+    session_id: str
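
For orientation, the sketch below shows how the session structures added above could be instantiated. It is illustrative only: the field values are invented, and the end-to-end flow through the tokenizer manager and the new session_controller.py is not part of this hunk.

from sglang.srt.managers.io_struct import (
    CloseSessionReqInput,
    GenerateReqInput,
    OpenSessionReqInput,
)

# Ask the runtime to open a session; the reply (OpenSessionReqOutput) carries the
# session_id to use in later requests.
open_req = OpenSessionReqInput(capacity_of_str_len=32768)

# A generation request that continues an existing session. session_rid refers to an
# earlier request in that session, as suggested by the field name; the exact
# semantics live in session_controller.py, which is not shown in this diff.
gen_req = GenerateReqInput(
    text="And then what happened?",
    session_id="example-session-id",
    session_rid="example-previous-rid",
)

# Tear the session down when the conversation is over.
close_req = CloseSessionReqInput(session_id="example-session-id")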

sglang/srt/managers/schedule_batch.py

@@ -34,6 +34,8 @@ import logging
 from typing import List, Optional, Tuple, Union
 
 import torch
+import triton
+import triton.language as tl
 
 from sglang.global_config import global_config
 from sglang.srt.configs.model_config import ModelConfig
@@ -55,7 +57,8 @@ global_server_args_dict = {
     "triton_attention_reduce_in_fp32": ServerArgs.triton_attention_reduce_in_fp32,
     "disable_mla": ServerArgs.disable_mla,
     "torchao_config": ServerArgs.torchao_config,
-    "disable_nan_detection": ServerArgs.disable_nan_detection,
+    "enable_nan_detection": ServerArgs.enable_nan_detection,
+    "enable_dp_attention": ServerArgs.enable_dp_attention,
 }
 
 
@@ -133,6 +136,7 @@ class ImageInputs:
     image_embeds: Optional[List[torch.Tensor]] = None
     aspect_ratio_ids: Optional[List[torch.Tensor]] = None
     aspect_ratio_mask: Optional[List[torch.Tensor]] = None
+
     # QWen2-VL related
     image_grid_thws: List[Tuple[int, int, int]] = None
     mrope_position_delta: Optional[torch.Tensor] = None
@@ -176,6 +180,7 @@ class Req:
         origin_input_ids: Tuple[int],
         sampling_params: SamplingParams,
         lora_path: Optional[str] = None,
+        session_id: Optional[str] = None,
     ):
         # Input and output info
         self.rid = rid
@@ -184,11 +189,12 @@
         self.origin_input_ids = origin_input_ids
         self.output_ids = []  # Each decode stage's output ids
         self.fill_ids = None  # fill_ids = origin_input_ids + output_ids
+        self.session_id = session_id
 
         self.sampling_params = sampling_params
         self.lora_path = lora_path
 
-        # Memory info
+        # Memory pool info
         self.req_pool_idx = None
 
         # Check finish
@@ -425,7 +431,7 @@ bid = 0
 
 @dataclasses.dataclass
 class ScheduleBatch:
-    """Store all inforamtion of a batch."""
+    """Store all inforamtion of a batch on the scheduler."""
 
     # Request, memory pool, and cache
     reqs: List[Req]
@@ -435,9 +441,9 @@ class ScheduleBatch:
 
     # For utility
     model_config: ModelConfig = None
-
     forward_mode: ForwardMode = None
     sampling_info: SamplingBatchInfo = None
+    next_batch_sampling_info: SamplingBatchInfo = None
 
     # Batched arguments to model runner
     input_ids: torch.Tensor = None
@@ -450,6 +456,10 @@
     # The sum of all sequence lengths
     seq_lens_sum: int = None
 
+    # For DP attention
+    global_num_tokens: Optional[List[int]] = None
+    can_run_dp_cuda_graph: bool = False
+
     # For processing logprobs
     return_logprob: bool = False
     top_logprobs_nums: Optional[List[int]] = None
@@ -502,7 +512,7 @@
     def is_empty(self):
         return len(self.reqs) == 0
 
-    def alloc_req_slots(self, num_reqs):
+    def alloc_req_slots(self, num_reqs: int):
         req_pool_indices = self.req_to_token_pool.alloc(num_reqs)
         if req_pool_indices is None:
             raise RuntimeError(
@@ -588,14 +598,14 @@
             )
 
         if not decoder_out_cache_loc:
-            self.out_cache_loc = torch.empty(0, dtype=torch.int32).to(
+            self.out_cache_loc = torch.zeros(0, dtype=torch.int32).to(
                 self.device, non_blocking=True
            )
         else:
             self.out_cache_loc = torch.cat(decoder_out_cache_loc)
 
         if not encoder_out_cache_loc:
-            self.encoder_out_cache_loc = torch.empty(0, dtype=torch.int32).to(
+            self.encoder_out_cache_loc = torch.zeros(0, dtype=torch.int32).to(
                 self.device, non_blocking=True
             )
         else:
@@ -603,7 +613,7 @@
 
         assert len(self.out_cache_loc) == self.extend_num_tokens
 
-    def prepare_for_extend(self):
+    def prepare_for_extend(self, enable_overlap_schedule: bool = False):
         self.forward_mode = ForwardMode.EXTEND
 
         bs = len(self.reqs)
@@ -611,12 +621,12 @@
         input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
         extend_num_tokens = sum(len(ids) for ids in input_ids)
         seq_lens = []
+        pre_lens = []
 
         # Allocate memory
         req_pool_indices = self.alloc_req_slots(bs)
         out_cache_loc = self.alloc_token_slots(extend_num_tokens)
 
-        pt = 0
         for i, req in enumerate(reqs):
             already_computed = (
                 req.extend_logprob_start_len + 1 + req.cached_tokens
@@ -634,10 +644,6 @@
             self.req_to_token_pool.write(
                 (req.req_pool_idx, slice(0, pre_len)), req.prefix_indices
             )
-            self.req_to_token_pool.write(
-                (req.req_pool_idx, slice(pre_len, seq_len)),
-                out_cache_loc[pt : pt + req.extend_input_len],
-            )
 
             # Compute the relative logprob_start_len in an extend batch
             if req.logprob_start_len >= pre_len:
@@ -648,8 +654,8 @@
                 extend_logprob_start_len = req.extend_input_len - 1
 
             req.extend_logprob_start_len = extend_logprob_start_len
-            pt += req.extend_input_len
             req.is_retracted = False
+            pre_lens.append(pre_len)
 
         # Set fields
         self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32).to(
@@ -661,7 +667,6 @@
         self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32).to(
             self.device, non_blocking=True
         )
-
         self.out_cache_loc = out_cache_loc
 
         self.seq_lens_sum = sum(seq_lens)
@@ -672,13 +677,37 @@
         self.extend_lens = [r.extend_input_len for r in reqs]
         self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]
 
+        # Write to req_to_token_pool
+        pre_lens = torch.tensor(pre_lens, dtype=torch.int32).to(
+            self.device, non_blocking=True
+        )
+        extend_lens = torch.tensor(self.extend_lens, dtype=torch.int32).to(
+            self.device, non_blocking=True
+        )
+        write_req_to_token_pool_triton[(bs,)](
+            self.req_to_token_pool.req_to_token,
+            self.req_pool_indices,
+            pre_lens,
+            self.seq_lens,
+            extend_lens,
+            self.out_cache_loc,
+            self.req_to_token_pool.req_to_token.shape[1],
+        )
+        # The triton kernel is equivalent to the following python code.
+        # self.req_to_token_pool.write(
+        #     (req.req_pool_idx, slice(pre_len, seq_len)),
+        #     out_cache_loc[pt : pt + req.extend_input_len],
+        # )
+        # TODO: some tensors can be reused for ForwardBatchInfo (e.g., extend_lens, cumsum_start)
+
         if self.model_config.is_encoder_decoder:
             self.prepare_encoder_info_extend(input_ids, seq_lens)
 
+        # Build sampling info
         self.sampling_info = SamplingBatchInfo.from_schedule_batch(
             self,
             self.model_config.vocab_size,
-            global_server_args_dict["disable_penalizer"],
+            enable_overlap_schedule=enable_overlap_schedule,
         )
 
     def mix_with_running(self, running_batch: "ScheduleBatch"):
@@ -720,6 +749,7 @@
         return False
 
     def retract_decode(self):
+        """Retract the decoding requests when there is not enough memory."""
         sorted_indices = [i for i in range(len(self.reqs))]
 
         # TODO(lsyin): improve retraction policy for radix cache
@@ -858,15 +888,21 @@
         # Reset the encoder cached status
         self.encoder_cached = [True] * len(self.reqs)
 
+    def prepare_for_idle(self):
+        self.forward_mode = ForwardMode.IDLE
+        self.input_ids = torch.empty(0, dtype=torch.int32, device=self.device)
+        self.seq_lens = torch.empty(0, dtype=torch.int32, device=self.device)
+        self.out_cache_loc = torch.empty(0, dtype=torch.int32, device=self.device)
+        self.req_pool_indices = torch.empty(0, dtype=torch.int32, device=self.device)
+        self.seq_lens_sum = 0
+        self.extend_num_tokens = 0
+
     def prepare_for_decode(self, enable_overlap: bool = False):
         self.forward_mode = ForwardMode.DECODE
 
         self.input_ids = self.output_ids
         self.output_ids = None
-        if self.sampling_info.penalizer_orchestrator:
-            self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
-                self.input_ids
-            )
+        self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(self.input_ids)
 
         # Alloc mem
         bs = len(self.reqs)
@@ -969,17 +1005,18 @@
         self.has_grammar = self.has_grammar or other.has_grammar
 
     def get_model_worker_batch(self):
-        if self.forward_mode.is_decode():
+        if self.forward_mode.is_decode() or self.forward_mode.is_idle():
             extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = None
         else:
             extend_seq_lens = self.extend_lens
             extend_prefix_lens = self.prefix_lens
             extend_logprob_start_lens = self.extend_logprob_start_lens
 
-        if self.has_grammar:
-            self.sampling_info.grammars = [req.grammar for req in self.reqs]
-        else:
-            self.sampling_info.grammars = None
+        if self.sampling_info:
+            if self.has_grammar:
+                self.sampling_info.grammars = [req.grammar for req in self.reqs]
+            else:
+                self.sampling_info.grammars = None
 
         global bid
         bid += 1
@@ -995,6 +1032,8 @@
             req_to_token_pool_records=self.req_to_token_pool.get_write_records(),
             return_logprob=self.return_logprob,
             top_logprobs_nums=self.top_logprobs_nums,
+            global_num_tokens=self.global_num_tokens,
+            can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
             extend_num_tokens=self.extend_num_tokens,
             extend_seq_lens=extend_seq_lens,
             extend_prefix_lens=extend_prefix_lens,
@@ -1051,6 +1090,10 @@ class ModelWorkerBatch:
     return_logprob: bool
     top_logprobs_nums: Optional[List[int]]
 
+    # For DP attention
+    global_num_tokens: Optional[List[int]]
+    can_run_dp_cuda_graph: bool
+
     # For extend
     extend_num_tokens: Optional[int]
     extend_seq_lens: Optional[List[int]]
@@ -1072,16 +1115,39 @@ class ModelWorkerBatch:
     # Sampling info
     sampling_info: SamplingBatchInfo
 
-    def copy(self):
-        return dataclasses.replace(self, sampling_info=self.sampling_info.copy())
-
-    def to(self, device: str):
-        self.input_ids = self.input_ids.to(device, non_blocking=True)
-        self.req_pool_indices = self.req_pool_indices.to(device, non_blocking=True)
-        self.seq_lens = self.seq_lens.to(device, non_blocking=True)
-        self.out_cache_loc = self.out_cache_loc.to(device, non_blocking=True)
-        self.req_to_token_pool_records = [
-            (x, y.to(device, non_blocking=True))
-            for x, y in self.req_to_token_pool_records
-        ]
-        self.sampling_info.to(device)
+
+@triton.jit
+def write_req_to_token_pool_triton(
+    req_to_token_ptr,  # [max_batch, max_context_len]
+    req_pool_indices,
+    pre_lens,
+    seq_lens,
+    extend_lens,
+    out_cache_loc,
+    req_to_token_ptr_stride: tl.constexpr,
+):
+    BLOCK_SIZE: tl.constexpr = 512
+    pid = tl.program_id(0)
+
+    req_pool_index = tl.load(req_pool_indices + pid)
+    pre_len = tl.load(pre_lens + pid)
+    seq_len = tl.load(seq_lens + pid)
+
+    # TODO: optimize this?
+    cumsum_start = 0
+    for i in range(pid):
+        cumsum_start += tl.load(extend_lens + i)
+
+    num_loop = tl.cdiv(seq_len - pre_len, BLOCK_SIZE)
+    for i in range(num_loop):
+        offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
+        mask = offset < (seq_len - pre_len)
+        value = tl.load(out_cache_loc + cumsum_start + offset, mask=mask)
+        tl.store(
+            req_to_token_ptr
+            + req_pool_index * req_to_token_ptr_stride
+            + offset
+            + pre_len,
+            value,
+            mask=mask,
+        )
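
The Triton kernel added above fills each request's row of the req_to_token map with that request's newly allocated KV-cache slot indices, starting right after its cached prefix. The plain-PyTorch loop below is a reference sketch of the same computation (it mirrors the commented-out req_to_token_pool.write call in prepare_for_extend); it is not code from the package.

import torch


def write_req_to_token_pool_reference(
    req_to_token: torch.Tensor,  # [max_batch, max_context_len]
    req_pool_indices: torch.Tensor,
    pre_lens: torch.Tensor,
    seq_lens: torch.Tensor,
    extend_lens: torch.Tensor,
    out_cache_loc: torch.Tensor,
) -> None:
    pt = 0
    for i in range(req_pool_indices.numel()):
        row = int(req_pool_indices[i])
        pre_len, seq_len = int(pre_lens[i]), int(seq_lens[i])
        extend_len = int(extend_lens[i])  # equals seq_len - pre_len for an extend batch
        # Write this request's new cache locations into its row, after the cached prefix.
        req_to_token[row, pre_len:seq_len] = out_cache_loc[pt : pt + extend_len]
        pt += extend_len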

sglang/srt/managers/schedule_policy.py

@@ -302,7 +302,11 @@ class PrefillAdder:
         if (
             self.rem_chunk_tokens is None
             or input_tokens <= self.rem_chunk_tokens
-            or (req.return_logprob and req.normalized_prompt_logprob is None)
+            or (
+                req.return_logprob
+                and req.normalized_prompt_logprob is None
+                and req.logprob_start_len != len(req.origin_input_ids) - 1
+            )
         ):
             # Non-chunked prefill
             self.can_run_list.append(req)